From 40e35acb09d9f65895e0ba69e80c7cdef6d335d6 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 5 Mar 2024 15:43:46 +0300 Subject: [PATCH 01/57] fix for delta db gen --- generate_delta_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generate_delta_dataset.py b/generate_delta_dataset.py index 3238603..fea48cb 100644 --- a/generate_delta_dataset.py +++ b/generate_delta_dataset.py @@ -75,7 +75,7 @@ def generate_delta_db( dft_data = row.data gfn_data = gfn_db.get(idx).data data = { - "energy": dft_data["energy"] - gfn_data["energy"], + "energy": [dft_data["energy"][0] - gfn_data["energy"]], "forces": dft_data["forces"] - gfn_data["forces"] } odb.write(row, data=data) From b1db069dc90c9f9a3523ea2f83a482465062c54d Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 12 Mar 2024 12:16:58 +0300 Subject: [PATCH 02/57] huge datasets rework --- config/datamodule/nablaDFT_ase.yaml | 3 +- config/datamodule/nablaDFT_ase_test.yaml | 3 +- config/datamodule/nablaDFT_hamiltonian.yaml | 13 + .../datamodule/nablaDFT_hamiltonian_test.yaml | 11 + config/datamodule/nablaDFT_pyg.yaml | 10 +- config/datamodule/nablaDFT_pyg_test.yaml | 9 +- nablaDFT/dataset/__init__.py | 5 +- nablaDFT/dataset/hamiltonian_dataset.py | 1 - nablaDFT/dataset/nablaDFT_dataset.py | 359 +++++------------- nablaDFT/dataset/pyg_datasets.py | 222 +++++++++++ nablaDFT/pipelines.py | 6 +- 11 files changed, 369 insertions(+), 273 deletions(-) create mode 100644 config/datamodule/nablaDFT_hamiltonian.yaml create mode 100644 config/datamodule/nablaDFT_hamiltonian_test.yaml create mode 100644 nablaDFT/dataset/pyg_datasets.py diff --git a/config/datamodule/nablaDFT_ase.yaml b/config/datamodule/nablaDFT_ase.yaml index 7b10ea7..1a426cc 100644 --- a/config/datamodule/nablaDFT_ase.yaml +++ b/config/datamodule/nablaDFT_ase.yaml @@ -1,7 +1,6 @@ # Dataset config for ASE nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.ASENablaDFT -type_of_nn: ASE split: ${job_type} datapath: ./datasets/nablaDFT/train/raw dataset_name: ${dataset_name} diff --git a/config/datamodule/nablaDFT_ase_test.yaml b/config/datamodule/nablaDFT_ase_test.yaml index 3907e44..2b11c2a 100644 --- a/config/datamodule/nablaDFT_ase_test.yaml +++ b/config/datamodule/nablaDFT_ase_test.yaml @@ -1,7 +1,6 @@ # Dataset config for ASE nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.ASENablaDFT -type_of_nn: ASE split: ${job_type} datapath: ./datasets/nablaDFT/test/raw dataset_name: ${dataset_name} diff --git a/config/datamodule/nablaDFT_hamiltonian.yaml b/config/datamodule/nablaDFT_hamiltonian.yaml new file mode 100644 index 0000000..692e582 --- /dev/null +++ b/config/datamodule/nablaDFT_hamiltonian.yaml @@ -0,0 +1,13 @@ +# Dataset config for torch geometric nablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianNablaDFT + +root: ./datasets/nablaDFT/train +dataset_name: ${dataset_name} +train_size: 0.9 +val_size: 0.1 +seed: 23 +# Dataloader args +batch_size: 2 +num_workers: 4 +persistent_workers: True +pin_memory: True diff --git a/config/datamodule/nablaDFT_hamiltonian_test.yaml b/config/datamodule/nablaDFT_hamiltonian_test.yaml new file mode 100644 index 0000000..71ea2ea --- /dev/null +++ b/config/datamodule/nablaDFT_hamiltonian_test.yaml @@ -0,0 +1,11 @@ +# Dataset config for torch geometric nablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianNablaDFT + +root: ./datasets/nablaDFT/test +dataset_name: ${dataset_name} +seed: 23 +# Dataloader args +batch_size: 2 +num_workers: 4 +persistent_workers: True +pin_memory: True \ No 
newline at end of file diff --git a/config/datamodule/nablaDFT_pyg.yaml b/config/datamodule/nablaDFT_pyg.yaml index 893a8ab..5c31eef 100644 --- a/config/datamodule/nablaDFT_pyg.yaml +++ b/config/datamodule/nablaDFT_pyg.yaml @@ -1,12 +1,14 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.PyGNablaDFTDataModule -type_of_nn: PyG -split: ${job_type} root: ./datasets/nablaDFT/train dataset_name: ${dataset_name} train_size: 0.9 val_size: 0.1 +seed: 23 +# Dataloader args batch_size: 32 num_workers: 8 -seed: 23 +persistent_workers: True +pin_memory: True + diff --git a/config/datamodule/nablaDFT_pyg_test.yaml b/config/datamodule/nablaDFT_pyg_test.yaml index 50ba6e8..5c81a7d 100644 --- a/config/datamodule/nablaDFT_pyg_test.yaml +++ b/config/datamodule/nablaDFT_pyg_test.yaml @@ -1,9 +1,10 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.NablaDFT +_target_: nablaDFT.dataset.PyGNablaDFTDataModule -type_of_nn: PyG -split: ${job_type} root: ./datasets/nablaDFT/test dataset_name: ${dataset_name} +# Dataloader args batch_size: 32 -num_workers: 12 \ No newline at end of file +num_workers: 12 +persistent_workers: True +pin_memory: True \ No newline at end of file diff --git a/nablaDFT/dataset/__init__.py b/nablaDFT/dataset/__init__.py index 04fa2e4..45127f0 100644 --- a/nablaDFT/dataset/__init__.py +++ b/nablaDFT/dataset/__init__.py @@ -1,2 +1,3 @@ -from .nablaDFT_dataset import NablaDFT -from .hamiltonian_dataset import HamiltonianDataset \ No newline at end of file +from .nablaDFT_dataset import PyGHamiltonianNablaDFT, ASENablaDFT, PyGNablaDFTDataModule +from .hamiltonian_dataset import HamiltonianDataset # database interface for Hamiltonian datasets +from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT # PyTorch Geometric interfaces for datasets \ No newline at end of file diff --git a/nablaDFT/dataset/hamiltonian_dataset.py b/nablaDFT/dataset/hamiltonian_dataset.py index 52eafd7..9ca89d7 100644 --- a/nablaDFT/dataset/hamiltonian_dataset.py +++ b/nablaDFT/dataset/hamiltonian_dataset.py @@ -302,7 +302,6 @@ def collate_fn(self, batch, return_filtered=False): tuple((int(z), int(l)) for l in self.database.get_orbitals(z)) ) local_orbitals_number += sum(2 * l + 1 for _, l in local_orbitals[-1]) - # print (local_orbitals_number, orbitals_number, len(local_orbitals), len(orbitals)) if ( orbitals_number + local_orbitals_number > self.max_batch_orbitals or len(local_orbitals) + len(orbitals) > self.max_batch_atoms diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index fcb3b06..4ffed25 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -1,3 +1,4 @@ +"""Module defines Pytorch Lightning DataModule interfaces for various NablaDFT datasets""" import json import os from typing import Optional, List @@ -6,15 +7,15 @@ import numpy as np import torch from ase.db import connect -from torch.utils.data import Subset -from torch_geometric.data.lightning import LightningDataset -from torch_geometric.data import InMemoryDataset, Data +from pytorch_lightning import LightningDataModule +from torch.utils.data import random_split +from torch_geometric.loader import DataLoader from schnetpack.data import AtomsDataFormat, load_dataset import nablaDFT from .atoms_datamodule import AtomsDataModule - from .hamiltonian_dataset import HamiltonianDatabase, HamiltonianDataset +from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT class 
ASENablaDFT(AtomsDataModule): @@ -91,258 +92,106 @@ def prepare_data(self): self.dataset = load_dataset(self.datapath, self.format) -class HamiltonianNablaDFT(HamiltonianDataset): - def __init__( - self, - datapath="database", - dataset_name="dataset_train_2k", - max_batch_orbitals=1200, - max_batch_atoms=150, - max_squares=4802, - subset=None, - dtype=torch.float32, - ): - self.dtype = dtype - if not os.path.exists(datapath): - os.makedirs(datapath) - with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: - data = json.load(f) - url = data["train_databases"][dataset_name] - filepath = datapath + "/" + dataset_name + ".db" - if not os.path.exists(filepath): - request.urlretrieve(url, filepath) - self.database = HamiltonianDatabase(filepath) - max_orbitals = [] - for z in self.database.Z: - max_orbitals.append( - tuple((int(z), int(l)) for l in self.database.get_orbitals(z)) - ) - max_orbitals = tuple(max_orbitals) - self.max_orbitals = max_orbitals - self.max_batch_orbitals = max_batch_orbitals - self.max_batch_atoms = max_batch_atoms - self.max_squares = max_squares - self.subset = None - if subset: - self.subset = np.load(subset) - +class PyGDataModule(LightningDataModule): + """Parent class which encapsulates PyG dataset for use with Pytorch Lightning Trainer. + In order to add new dataset variant, define children class with setup() method. -class PyGNablaDFT(InMemoryDataset): - """Dataset adapter for ASE2PyG conversion. - Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py + Args: + - root (str): path to directory with r'raw/' subfolder with existing dataset or download location. + - dataset_name (str): split name from links .json file. + - train_size (float): part of dataset used for training, must be in [0, 1]. + - val_size (float): part of dataset used for validation, must be in [0, 1]. + - seed (int): seed number, used for torch.Generator object during train/val split. + - kwargs (Dict): other arguments for dataset. 
""" - - db_suffix = ".db" - - @property - def raw_file_names(self) -> List[str]: - return [(self.dataset_name + self.db_suffix)] - - @property - def processed_file_names(self) -> str: - return f"{self.dataset_name}_{self.split}.pt" - def __init__( - self, - datapath: str = "database", - dataset_name: str = "dataset_train_2k", - split: str = "train", - transform=None, - pre_transform=None, - ): + self, + root: str, + dataset_name: str, + train_size: float = 0.9, + val_size: float = 0.1, + seed: int = 23, + **kwargs + ) -> None: + super().__init__() + self.dataset_train = None + self.dataset_val = None + self.dataset_test = None + self.dataset_predict = None + + self.root = root self.dataset_name = dataset_name - self.datapath = datapath - self.split = split - self.data_all, self.slices_all = [], [] - self.offsets = [0] - super(PyGNablaDFT, self).__init__(datapath, transform, pre_transform) - - for path in self.processed_paths: - data, slices = torch.load(path) - self.data_all.append(data) - self.slices_all.append(slices) - self.offsets.append( - len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] - ) - - def len(self) -> int: - return sum( - len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all - ) - - def get(self, idx): - data_idx = 0 - while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: - data_idx += 1 - self.data = self.data_all[data_idx] - self.slices = self.slices_all[data_idx] - return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) - - def download(self) -> None: - with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: - data = json.load(f) - url = data[f"{self.split}_databases"][self.dataset_name] - request.urlretrieve(url, self.raw_paths[0]) - - def process(self) -> None: - db = connect(self.raw_paths[0]) - samples = [] - for db_row in db.select(): - z = torch.from_numpy(db_row.numbers).long() - positions = torch.from_numpy(db_row.positions).float() - y = torch.from_numpy(np.array(db_row.data["energy"])).float() - # TODO: temp workaround for dataset w/o forces - forces = db_row.data.get("forces", None) - if forces is not None: - forces = torch.from_numpy(np.array(forces)).float() - samples.append(Data(z=z, pos=positions, y=y, forces=forces)) - - if self.pre_filter is not None: - samples = [data for data in samples if self.pre_filter(data)] - - if self.pre_transform is not None: - samples = [self.pre_transform(data) for data in samples] - - data, slices = self.collate(samples) - torch.save((data, slices), self.processed_paths[0]) - - -# From https://github.com/torchmd/torchmd-net/blob/72cdc6f077b2b880540126085c3ed59ba1b6d7e0/torchmdnet/utils.py#L54 -def train_val_split(dset_len, train_size, val_size, seed, order=None): - assert (train_size is None) + ( - val_size is None - ) <= 1, "Only one of train_size, val_size is allowed to be None." - is_float = ( - isinstance(train_size, float), - isinstance(val_size, float), - ) - - train_size = round(dset_len * train_size) if is_float[0] else train_size - val_size = round(dset_len * val_size) if is_float[1] else val_size - - if train_size is None: - train_size = dset_len - val_size - elif val_size is None: - val_size = dset_len - train_size - - if train_size + val_size > dset_len: - if is_float[1]: - val_size -= 1 - elif is_float[0]: - train_size -= 1 - - assert train_size >= 0 and val_size >= 0, ( - f"One of training ({train_size}), validation ({val_size})" - f" splits ended up with a negative size." 
- ) - - total = train_size + val_size - - assert dset_len >= total, ( - f"The dataset ({dset_len}) is smaller than the " - f"combined split sizes ({total})." - ) - if total < dset_len: - print(f"{dset_len - total} samples were excluded from the dataset") - - idxs = np.arange(dset_len, dtype=int) - if order is None: - idxs = np.random.default_rng(seed).permutation(idxs) - - idx_train = idxs[:train_size] - idx_val = idxs[train_size:total] - - if order is not None: - idx_train = [order[i] for i in idx_train] - idx_val = [order[i] for i in idx_val] - - return np.array(idx_train), np.array(idx_val) - - -# From: https://github.com/torchmd/torchmd-net/blob/72cdc6f077b2b880540126085c3ed59ba1b6d7e0/torchmdnet/utils.py#L112 -def make_splits( - dataset_len, - train_size, - val_size, - seed, - filename=None, # path to save split index - splits=None, - order=None, -): - if splits is not None: - splits = np.load(splits) - idx_train = splits["idx_train"] - idx_val = splits["idx_val"] - else: - idx_train, idx_val = train_val_split( - dataset_len, train_size, val_size, seed, order - ) - - if filename is not None: - np.savez(filename, idx_train=idx_train, idx_val=idx_val) - - return ( - torch.from_numpy(idx_train), - torch.from_numpy(idx_val), - ) - - -def get_PyG_nablaDFT_datasets( - root: str, - split: str, - dataset_name: str, - train_size: float = None, - val_size: float = None, - batch_size: int = None, - num_workers: int = None, - seed: int = None, -): - dataset = PyGNablaDFT(root, dataset_name, split=split) - if split == "train": - idx_train, idx_val = make_splits( - len(dataset), - train_size, - val_size, - seed, - filename=os.path.join(root, "splits.npz"), - splits=None, - ) - train_dataset = Subset(dataset, idx_train) - val_dataset = Subset(dataset, idx_val) - test_dataset = None - pred_dataset = None - else: - train_dataset = None - val_dataset = None - if split == "predict": - pred_dataset = dataset - test_dataset = None - else: - test_dataset = dataset - pred_dataset = None - - pl_datamodule = LightningDataset( - train_dataset=train_dataset, - val_dataset=val_dataset, - test_dataset=test_dataset, - pred_dataset=pred_dataset, - batch_size=batch_size, - num_workers=num_workers, - pin_memory=True, - persistent_workers=True, - ) - return pl_datamodule - - -class NablaDFT: - def __init__(self, type_of_nn, *args, **kwargs): - valid = {"ASE", "Hamiltonian", "PyG"} - if type_of_nn not in valid: - raise ValueError("results: type of nn must be one of %r." 
% valid) - self.type_of_nn = type_of_nn - if self.type_of_nn == "ASE": - self.dataset = ASENablaDFT(*args, **kwargs) - elif self.type_of_nn == "Hamiltonian": - self.dataset = HamiltonianNablaDFT(*args, **kwargs) - else: - self.dataset = get_PyG_nablaDFT_datasets(*args, **kwargs) + self.seed = seed + self.sizes = [train_size, val_size] + dataloader_keys = [ + "batch_size", "num_workers", + "pin_memory", "persistent_workers" + ] + self.dataloader_kwargs = {} + for key in dataloader_keys: + val = kwargs.get(key, None) + self.dataloader_kwargs[key] = val + if val is not None: + del kwargs[key] + self.kwargs = kwargs + + def dataloader(self, dataset, **kwargs): + return DataLoader(dataset, **kwargs) + + def setup(self, stage: str) -> None: + raise NotImplementedError + + def train_dataloader(self): + return self.dataloader(self.dataset_train, shuffle=True, **self.dataloader_kwargs) + + def val_dataloader(self): + return self.dataloader(self.dataset_val, shuffle=False, **self.dataloader_kwargs) + + def test_dataloader(self): + return self.dataloader(self.dataset_test, shuffle=False, **self.dataloader_kwargs) + + def predict_dataloader(self): + return self.dataloader(self.dataset_predict, shuffle=False, **self.dataloader_kwargs) + + +class PyGHamiltonianDataModule(PyGDataModule): + """DataModule for Hamiltonian NablaDFT dataset + + Keyword arguments: + - hamiltonian (bool): retrieve from database molecule's full hamiltonian matrix. True by default. + - include_overlap (bool): retrieve from database molecule's overlab matrix. + - include_core (bool): retrieve from databaes molecule's core hamiltonian matrix. + """ + def __init__( + self, + root: str, + dataset_name: str, + train_size: float = None, + val_size: float = None, + **kwargs) -> None: + super().__init__(root, dataset_name, train_size, val_size, **kwargs) + + def setup(self, stage: str) -> None: + if stage == "fit": + dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) + self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, + generator=torch.Generator().manual_seed(self.seed)) + elif stage == "test": + self.dataset_test = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "test", **self.kwargs) + elif stage == "predict": + self.dataset_predict = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "predict", **self.kwargs) + + +class PyGNablaDFTDataModule(PyGDataModule): + def __init__(self, root: str, dataset_name: str, train_size: float = None, val_size: float = None, **kwargs) -> None: + super().__init__(root, dataset_name, train_size, val_size, **kwargs) + + def setup(self, stage: str) -> None: + if stage == "fit": + dataset = PyGNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) + self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, + generator=torch.Generator().manual_seed(self.seed)) + elif stage == "test": + self.dataset_test = PyGNablaDFT(self.root, self.dataset_name, "test", **self.kwargs) + elif stage == "predict": + self.dataset_predict = PyGNablaDFT(self.root, self.dataset_name, "predict", **self.kwargs) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py new file mode 100644 index 0000000..49ae647 --- /dev/null +++ b/nablaDFT/dataset/pyg_datasets.py @@ -0,0 +1,222 @@ +"""Module describes PyTorch Geometric interfaces for various NablaDFT datasets""" +import json +import os +from typing import List, Callable +from urllib import request as request + +import numpy as np +import torch +from ase.db import connect 
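# A minimal usage sketch for the new datamodules defined in nablaDFT_dataset.py above
# (assumes an already-configured LightningModule `model` and pl.Trainer `trainer`;
# the argument values mirror config/datamodule/nablaDFT_pyg.yaml):
#
#     datamodule = PyGNablaDFTDataModule(
#         root="./datasets/nablaDFT/train",
#         dataset_name="dataset_train_2k",
#         train_size=0.9, val_size=0.1, seed=23,
#         batch_size=32, num_workers=8,
#         persistent_workers=True, pin_memory=True,
#     )
#     trainer.fit(model=model, datamodule=datamodule)  # setup("fit") downloads/processes the split and builds train/val subsets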
+from torch_geometric.data import InMemoryDataset, Data + +import nablaDFT +from .hamiltonian_dataset import HamiltonianDatabase + + +class PyGNablaDFT(InMemoryDataset): + """Dataset adapter for ASE2PyG conversion. + Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py + """ + + db_suffix = ".db" + + @property + def raw_file_names(self) -> List[str]: + return [(self.dataset_name + self.db_suffix)] + + @property + def processed_file_names(self) -> str: + return f"{self.dataset_name}_{self.split}.pt" + + def __init__( + self, + datapath: str = "database", + dataset_name: str = "dataset_train_2k", + split: str = "train", + transform: Callable = None, + pre_transform: Callable = None, + ): + self.dataset_name = dataset_name + self.datapath = datapath + self.split = split + self.data_all, self.slices_all = [], [] + self.offsets = [0] + super(PyGNablaDFT, self).__init__(datapath, transform, pre_transform) + + for path in self.processed_paths: + data, slices = torch.load(path) + self.data_all.append(data) + self.slices_all.append(slices) + self.offsets.append( + len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] + ) + + def len(self) -> int: + return sum( + len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all + ) + + def get(self, idx): + data_idx = 0 + while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: + data_idx += 1 + self.data = self.data_all[data_idx] + self.slices = self.slices_all[data_idx] + return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) + + def download(self) -> None: + with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: + data = json.load(f) + url = data[f"{self.split}_databases"][self.dataset_name] + request.urlretrieve(url, self.raw_paths[0]) + + def process(self) -> None: + db = connect(self.raw_paths[0]) + samples = [] + for db_row in db.select(): + z = torch.from_numpy(db_row.numbers).long() + positions = torch.from_numpy(db_row.positions).float() + y = torch.from_numpy(np.array(db_row.data["energy"])).float() + # TODO: temp workaround for dataset w/o forces + forces = db_row.data.get("forces", None) + if forces is not None: + forces = torch.from_numpy(np.array(forces)).float() + samples.append(Data(z=z, pos=positions, y=y, forces=forces)) + + if self.pre_filter is not None: + samples = [data for data in samples if self.pre_filter(data)] + + if self.pre_transform is not None: + samples = [self.pre_transform(data) for data in samples] + + data, slices = self.collate(samples) + torch.save((data, slices), self.processed_paths[0]) + + +# TODO: move this to OnDiskDataset +class PyGHamiltonianNablaDFT(InMemoryDataset): + """Pytorch Geometric dataset for NablaDFT Hamiltonian database. + + Args: + - datapath (str): path to existing dataset directory or location for download. + - dataset_name (str): split name from links .json. + - split (str): type of split, must be one of ['train', 'test', 'predict']. + - include_hamiltonian (bool): if True, retrieves full Hamiltonian matrices from database. + - include_overlap (bool): if True, retrieves overlap matrices from database. + - include_core (bool): if True, retrieves core Hamiltonian matrices from database. + - dtype (torch.dtype): defines torch.dtype for energy, positions, forces tensors. + - transform (Callable): callable data transform, called on every access to element. + - pre_transform (Callable): callable data transform, called during process() for every element. 
+ Note: + Hamiltonian matrix for each molecule has different shape. PyTorch Geometric tries to concatenate + each torch.Tensor in batch, so in order to make batch from data we leave all hamiltonian matrices + in numpy array form. During train, these matrices will be yield as List[np.array]. + """ + db_suffix = ".db" + + @property + def raw_file_names(self) -> List[str]: + return [(self.dataset_name + self.db_suffix)] + + @property + def processed_file_names(self) -> str: + return f"{self.dataset_name}_{self.split}.pt" + + def __init__( + self, + datapath="database", + dataset_name="dataset_train_2k", + split: str="train", + include_hamiltonian: bool = True, + include_overlap: bool = False, + include_core: bool = False, + dtype=torch.float32, + transform: Callable = None, + pre_transform: Callable = None, + ): + self.dataset_name = dataset_name + self.datapath = datapath + self.split = split + self.data_all, self.slices_all = [], [] + self.offsets = [0] + self.dtype = dtype + self.max_orbitals = self._get_max_orbitals(datapath, dataset_name) + self.include_hamiltonian = include_hamiltonian + self.include_overlap = include_overlap + self.include_core = include_core + + super(PyGHamiltonianNablaDFT, self).__init__(datapath, transform, pre_transform) + + for path in self.processed_paths: + data, slices = torch.load(path) + self.data_all.append(data) + self.slices_all.append(slices) + self.offsets.append( + len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] + ) + + def len(self) -> int: + return sum( + len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all + ) + + def get(self, idx): + data_idx = 0 + while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: + data_idx += 1 + self.data = self.data_all[data_idx] + self.slices = self.slices_all[data_idx] + return super(PyGHamiltonianNablaDFT, self).get(idx - self.offsets[data_idx]) + + def download(self) -> None: + with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: + data = json.load(f) + url = data["train_databases"][self.dataset_name] + request.urlretrieve(url, self.raw_paths[0]) + + def process(self) -> None: + database = HamiltonianDatabase(self.raw_paths[0]) + samples = [] + for idx in range(len(database)): + data = database[idx] + z = torch.tensor(data[0]).long() + positions = torch.tensor(data[1]).to(self.dtype) + # see notes + hamiltonian = data[4] + if self.include_overlap: + overlap = data[5] + else: + overlap = None + if self.include_core: + core = data[6] + else: + core = None + y = torch.from_numpy(data[2]).to(self.dtype) + forces = torch.from_numpy(data[3]).to(self.dtype) + samples.append(Data( + z=z, pos=positions, + y=y, forces=forces, + hamiltonian=hamiltonian, + overlap=overlap, + core=core, + )) + + if self.pre_filter is not None: + samples = [data for data in samples if self.pre_filter(data)] + + if self.pre_transform is not None: + samples = [self.pre_transform(data) for data in samples] + + data, slices = self.collate(samples) + torch.save((data, slices), self.processed_paths[0]) + + def _get_max_orbitals(self, datapath, dataset_name): + db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) + database = HamiltonianDatabase(db_path) + max_orbitals = [] + for z in database.Z: + max_orbitals.append( + tuple((int(z), int(l)) for l in database.get_orbitals(z)) + ) + max_orbitals = tuple(max_orbitals) + return max_orbitals diff --git a/nablaDFT/pipelines.py b/nablaDFT/pipelines.py index dfee265..21a6aed 100644 --- a/nablaDFT/pipelines.py +++ 
b/nablaDFT/pipelines.py @@ -35,7 +35,7 @@ def predict( pred_path = os.path.join(os.getcwd(), "predictions") os.makedirs(pred_path, exist_ok=True) predictions = trainer.predict( - model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path + model=model, datamodule=datamodule, ckpt_path=ckpt_path ) torch.save(predictions, f"{pred_path}/{config.name}_{config.dataset_name}.pt") @@ -105,9 +105,9 @@ def run(config: DictConfig): # Datamodule datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) if job_type == "train": - trainer.fit(model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path) elif job_type == "test": - trainer.test(model=model, datamodule=datamodule.dataset, ckpt_path=ckpt_path) + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) elif job_type == "predict": predict(trainer, model, datamodule, ckpt_path, config) # Finalize From 28c8b8378039c9f00e2b5ee52ca46479b15cf44a Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 12 Mar 2024 13:12:56 +0300 Subject: [PATCH 03/57] make dataset processing more verbose --- nablaDFT/dataset/nablaDFT_dataset.py | 2 +- nablaDFT/dataset/pyg_datasets.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 4ffed25..02971e5 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -173,7 +173,7 @@ def __init__( def setup(self, stage: str) -> None: if stage == "fit": - dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) + dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs)[:50] # TODO: temporal subset self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed)) elif stage == "test": diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index 49ae647..d9548b8 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -1,9 +1,11 @@ """Module describes PyTorch Geometric interfaces for various NablaDFT datasets""" import json import os +import logging from typing import List, Callable from urllib import request as request +from tqdm import tqdm import numpy as np import torch from ase.db import connect @@ -13,6 +15,9 @@ from .hamiltonian_dataset import HamiltonianDatabase +logger = logging.getLogger(__name__) + + class PyGNablaDFT(InMemoryDataset): """Dataset adapter for ASE2PyG conversion. 
Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py @@ -65,6 +70,7 @@ def get(self, idx): return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) def download(self) -> None: + logger.info(f"Downloading split: {self.dataset_name}") with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: data = json.load(f) url = data[f"{self.split}_databases"][self.dataset_name] @@ -73,7 +79,7 @@ def download(self) -> None: def process(self) -> None: db = connect(self.raw_paths[0]) samples = [] - for db_row in db.select(): + for db_row in tqdm(db.select(), total=len(db)): z = torch.from_numpy(db_row.numbers).long() positions = torch.from_numpy(db_row.positions).float() y = torch.from_numpy(np.array(db_row.data["energy"])).float() @@ -91,6 +97,7 @@ def process(self) -> None: data, slices = self.collate(samples) torch.save((data, slices), self.processed_paths[0]) + logger.info(f"Saved processed dataset: {self.processed_paths[0]}") # TODO: move this to OnDiskDataset @@ -169,6 +176,7 @@ def get(self, idx): return super(PyGHamiltonianNablaDFT, self).get(idx - self.offsets[data_idx]) def download(self) -> None: + logger.info(f"Downloading split: {self.dataset_name}") with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: data = json.load(f) url = data["train_databases"][self.dataset_name] @@ -177,7 +185,7 @@ def download(self) -> None: def process(self) -> None: database = HamiltonianDatabase(self.raw_paths[0]) samples = [] - for idx in range(len(database)): + for idx in tqdm(range(len(database)), total=len(database)): data = database[idx] z = torch.tensor(data[0]).long() positions = torch.tensor(data[1]).to(self.dtype) @@ -209,6 +217,7 @@ def process(self) -> None: data, slices = self.collate(samples) torch.save((data, slices), self.processed_paths[0]) + logger.info(f"Saved processed dataset: {self.processed_paths[0]}") def _get_max_orbitals(self, datapath, dataset_name): db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) From 9c917bbc6c5f58c7c8f405aee1af169e38fdc012 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 12 Mar 2024 15:10:41 +0300 Subject: [PATCH 04/57] make dataset download more verbose --- nablaDFT/dataset/nablaDFT_dataset.py | 3 +-- nablaDFT/dataset/pyg_datasets.py | 14 ++++++++++---- nablaDFT/utils.py | 29 +++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 02971e5..2b710e2 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -1,7 +1,7 @@ """Module defines Pytorch Lightning DataModule interfaces for various NablaDFT datasets""" import json import os -from typing import Optional, List +from typing import Optional from urllib import request as request import numpy as np @@ -14,7 +14,6 @@ import nablaDFT from .atoms_datamodule import AtomsDataModule -from .hamiltonian_dataset import HamiltonianDatabase, HamiltonianDataset from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index d9548b8..95c8e6f 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -13,7 +13,7 @@ import nablaDFT from .hamiltonian_dataset import HamiltonianDatabase - +from nablaDFT.utils import tqdm_download_hook, get_file_size logger = logging.getLogger(__name__) @@ -74,7 +74,9 @@ def download(self) -> None: with 
open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: data = json.load(f) url = data[f"{self.split}_databases"][self.dataset_name] - request.urlretrieve(url, self.raw_paths[0]) + file_size = get_file_size(url) + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size) as t: + request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) def process(self) -> None: db = connect(self.raw_paths[0]) @@ -147,13 +149,13 @@ def __init__( self.data_all, self.slices_all = [], [] self.offsets = [0] self.dtype = dtype - self.max_orbitals = self._get_max_orbitals(datapath, dataset_name) self.include_hamiltonian = include_hamiltonian self.include_overlap = include_overlap self.include_core = include_core super(PyGHamiltonianNablaDFT, self).__init__(datapath, transform, pre_transform) + self.max_orbitals = self._get_max_orbitals(datapath, dataset_name) for path in self.processed_paths: data, slices = torch.load(path) self.data_all.append(data) @@ -180,7 +182,9 @@ def download(self) -> None: with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: data = json.load(f) url = data["train_databases"][self.dataset_name] - request.urlretrieve(url, self.raw_paths[0]) + file_size = get_file_size(url) + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size) as t: + request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) def process(self) -> None: database = HamiltonianDatabase(self.raw_paths[0]) @@ -221,6 +225,8 @@ def process(self) -> None: def _get_max_orbitals(self, datapath, dataset_name): db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) + if not os.path.exists(db_path): + self.download() database = HamiltonianDatabase(db_path) max_orbitals = [] for z in database.Z: diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py index fd165d4..1a95f02 100644 --- a/nablaDFT/utils.py +++ b/nablaDFT/utils.py @@ -7,6 +7,7 @@ import logging import warnings +from tqdm import tqdm import pytorch_lightning as pl from pytorch_lightning import LightningModule from pytorch_lightning.utilities import rank_zero_only @@ -25,6 +26,31 @@ logger = logging.getLogger() +def get_file_size(url: str) -> int: + """Returns file size in bytes""" + req = request.Request(url, method="HEAD") + with request.urlopen(req) as f: + file_size = f.headers.get('Content-Length') + return int(file_size) + +def tqdm_download_hook(t): + """wraps TQDM progress bar instance""" + last_block = [0] + + def update_to(blocks_count: int, block_size: int, total_size: int): + """Adds progress bar for request.urlretrieve() method + Args: + - blocks_count (int): transferred blocks count. + - block_size (int): size of block in bytes. + - total_size (int): size of requested file. 
+        """
+        if total_size not in (None, -1):
+            t.total = total_size
+        displayed = t.update((blocks_count - last_block[0]) * block_size)
+        last_block[0] = blocks_count
+        return displayed
+    return update_to
+
 
 def seed_everything(seed=42):
     random.seed(seed)
@@ -74,7 +100,8 @@ def download_model(config: DictConfig) -> str:
     with open(nablaDFT.__path__[0] + "/links/models_checkpoints.json", "r") as f:
         data = json.load(f)
     url = data[f"{model_name}"]["dataset_train_100k"]
-    request.urlretrieve(url, ckpt_path)
+    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+        request.urlretrieve(url, ckpt_path, reporthook=tqdm_download_hook(t))
     logging.info(f"Downloaded {model_name} 100k checkpoint to {ckpt_path}")
     return ckpt_path
 
From 5f4de37adaa35918193b7d0aee74e3a82cb1097d Mon Sep 17 00:00:00 2001
From: BerAnton
Date: Tue, 12 Mar 2024 15:21:52 +0300
Subject: [PATCH 05/57] fixes

---
 nablaDFT/dataset/pyg_datasets.py | 6 ++----
 nablaDFT/utils.py | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py
index 95c8e6f..e7d2c9f 100644
--- a/nablaDFT/dataset/pyg_datasets.py
+++ b/nablaDFT/dataset/pyg_datasets.py
@@ -70,12 +70,11 @@ def get(self, idx):
         return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx])
 
     def download(self) -> None:
-        logger.info(f"Downloading split: {self.dataset_name}")
         with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f:
             data = json.load(f)
         url = data[f"{self.split}_databases"][self.dataset_name]
         file_size = get_file_size(url)
-        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size) as t:
+        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading split: {self.dataset_name}") as t:
             request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t))
 
     def process(self) -> None:
@@ -178,12 +177,11 @@ def get(self, idx):
         return super(PyGHamiltonianNablaDFT, self).get(idx - self.offsets[data_idx])
 
     def download(self) -> None:
-        logger.info(f"Downloading split: {self.dataset_name}")
         with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f:
             data = json.load(f)
         url = data["train_databases"][self.dataset_name]
         file_size = get_file_size(url)
-        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size) as t:
+        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading split: {self.dataset_name}") as t:
             request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t))
 
     def process(self) -> None:
diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py
index 1a95f02..4aa579e 100644
--- a/nablaDFT/utils.py
+++ b/nablaDFT/utils.py
@@ -100,7 +100,7 @@ def download_model(config: DictConfig) -> str:
     with open(nablaDFT.__path__[0] + "/links/models_checkpoints.json", "r") as f:
         data = json.load(f)
     url = data[f"{model_name}"]["dataset_train_100k"]
-    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t:
+    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"Downloading {model_name} checkpoint") as t:
         request.urlretrieve(url, ckpt_path, reporthook=tqdm_download_hook(t))
     logging.info(f"Downloaded {model_name} 100k checkpoint to {ckpt_path}")
     return ckpt_path
 
From 348b6fef15c910026f56fbfcb331c51c79852beb Mon Sep 17 00:00:00 2001
From: BerAnton
Date: Tue, 12 Mar 2024 19:24:51 +0300
Subject: [PATCH 06/57] fix imports

---
 config/datamodule/nablaDFT_hamiltonian.yaml | 4 ++--
nablaDFT/dataset/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/datamodule/nablaDFT_hamiltonian.yaml b/config/datamodule/nablaDFT_hamiltonian.yaml index 692e582..989b77f 100644 --- a/config/datamodule/nablaDFT_hamiltonian.yaml +++ b/config/datamodule/nablaDFT_hamiltonian.yaml @@ -1,7 +1,7 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.PyGHamiltonianNablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianDataModule -root: ./datasets/nablaDFT/train +root: ./datasets/nablaDFT/hamiltonian dataset_name: ${dataset_name} train_size: 0.9 val_size: 0.1 diff --git a/nablaDFT/dataset/__init__.py b/nablaDFT/dataset/__init__.py index 45127f0..bde9da1 100644 --- a/nablaDFT/dataset/__init__.py +++ b/nablaDFT/dataset/__init__.py @@ -1,3 +1,3 @@ -from .nablaDFT_dataset import PyGHamiltonianNablaDFT, ASENablaDFT, PyGNablaDFTDataModule +from .nablaDFT_dataset import ASENablaDFT, PyGNablaDFTDataModule, PyGHamiltonianDataModule from .hamiltonian_dataset import HamiltonianDataset # database interface for Hamiltonian datasets from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT # PyTorch Geometric interfaces for datasets \ No newline at end of file From 9fca72d9f11571cf6344d6e12715aee97c824eaa Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 12 Mar 2024 19:27:13 +0300 Subject: [PATCH 07/57] qhnet initial commit --- config/model/qhnet.yaml | 39 +++ config/qhnet.yaml | 26 ++ nablaDFT/qhnet/__init__.py | 1 + nablaDFT/qhnet/layers.py | 645 +++++++++++++++++++++++++++++++++++++ nablaDFT/qhnet/loss.py | 14 + nablaDFT/qhnet/qhnet.py | 478 +++++++++++++++++++++++++++ 6 files changed, 1203 insertions(+) create mode 100644 config/model/qhnet.yaml create mode 100644 config/qhnet.yaml create mode 100644 nablaDFT/qhnet/__init__.py create mode 100644 nablaDFT/qhnet/layers.py create mode 100644 nablaDFT/qhnet/loss.py create mode 100644 nablaDFT/qhnet/qhnet.py diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml new file mode 100644 index 0000000..aa2e282 --- /dev/null +++ b/config/model/qhnet.yaml @@ -0,0 +1,39 @@ +_target_: nablaDFT.qhnet.QHNetLightning + +model_name: "QHNet" +net: + _target_: nablaDFT.qhnet.QHNet + sh_lmax: 4 + hidden_size: 128 + bottle_hidden_size: 32 + num_gnn_layers: 5 + max_radius: 12 + num_nodes: 83 + radius_embed_dim: 32 + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + amsgrad: true + betas: [0.9, 0.95] + lr: 1e-3 + weight_decay: 0 + +lr_scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + factor: 0.8 + patience: 10 + +losses: + hamiltonian: + _target_: nablaDFT.qhnet.loss.MAE_RMSE_Loss +loss_coefs: + hamiltonian: 1.0 + +metric: + _target_: torchmetrics.MultitaskWrapper + _convert_: all + task_metrics: + hamiltonian: + _target_: torchmetrics.MeanAbsoluteError diff --git a/config/qhnet.yaml b/config/qhnet.yaml new file mode 100644 index 0000000..495be9b --- /dev/null +++ b/config/qhnet.yaml @@ -0,0 +1,26 @@ +# Global variables +name: QHNet +dataset_name: dataset_train_2k +max_steps: 1000000 +warmup_steps: 0 +job_type: train +pretrained: False +ckpt_path: null # path to checkpoint for training resume or test run + +# configs +defaults: + - _self_ + - datamodule: nablaDFT_hamiltonian.yaml # dataset config + - model: qhnet.yaml # model config + - callbacks: default.yaml # pl callbacks config + - loggers: wandb.yaml # pl loggers config + - trainer: train.yaml # trainer config + +# need this to set working dir as current dir +hydra: + output_subdir: null + run: + dir: . 
+original_work_dir: ${hydra:runtime.cwd}
+
+seed: 23
\ No newline at end of file
diff --git a/nablaDFT/qhnet/__init__.py b/nablaDFT/qhnet/__init__.py
new file mode 100644
index 0000000..d38b468
--- /dev/null
+++ b/nablaDFT/qhnet/__init__.py
@@ -0,0 +1 @@
+from .qhnet import QHNet, QHNetLightning
\ No newline at end of file
diff --git a/nablaDFT/qhnet/layers.py b/nablaDFT/qhnet/layers.py
new file mode 100644
index 0000000..16f53f8
--- /dev/null
+++ b/nablaDFT/qhnet/layers.py
@@ -0,0 +1,645 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from e3nn import o3
+from torch_scatter import scatter
+from e3nn.nn import FullyConnectedNet
+from e3nn.o3 import Linear, TensorProduct
+
+
+def prod(x):
+    """Compute the product of a sequence."""
+    out = 1
+    for a in x:
+        out *= a
+    return out
+
+
+def ShiftedSoftPlus(x):
+    return torch.nn.functional.softplus(x) - math.log(2.0)
+
+
+def softplus_inverse(x):
+    if not isinstance(x, torch.Tensor):
+        x = torch.tensor(x)
+    return x + torch.log(-torch.expm1(-x))
+
+
+def get_nonlinear(nonlinear: str):
+    if nonlinear.lower() == 'ssp':
+        return ShiftedSoftPlus
+    elif nonlinear.lower() == 'silu':
+        return F.silu
+    elif nonlinear.lower() == 'tanh':
+        return F.tanh
+    elif nonlinear.lower() == 'abs':
+        return torch.abs
+    else:
+        raise NotImplementedError
+
+
+def get_feasible_irrep(irrep_in1, irrep_in2, cutoff_irrep_out, tp_mode="uvu"):
+    irrep_mid = []
+    instructions = []
+
+    for i, (_, ir_in) in enumerate(irrep_in1):
+        for j, (_, ir_edge) in enumerate(irrep_in2):
+            for ir_out in ir_in * ir_edge:
+                if ir_out in cutoff_irrep_out:
+                    if (cutoff_irrep_out.count(ir_out), ir_out) not in irrep_mid:
+                        k = len(irrep_mid)
+                        irrep_mid.append((cutoff_irrep_out.count(ir_out), ir_out))
+                    else:
+                        k = irrep_mid.index((cutoff_irrep_out.count(ir_out), ir_out))
+                    instructions.append((i, j, k, tp_mode, True))
+
+    irrep_mid = o3.Irreps(irrep_mid)
+    normalization_coefficients = []
+    for ins in instructions:
+        ins_dict = {
+            'uvw': (irrep_in1[ins[0]].mul * irrep_in2[ins[1]].mul),
+            'uvu': irrep_in2[ins[1]].mul,
+            'uvv': irrep_in1[ins[0]].mul,
+            'uuw': irrep_in1[ins[0]].mul,
+            'uuu': 1,
+            'uvuv': 1,
+            'uvu<v': irrep_in1[ins[0]].mul * (irrep_in2[ins[1]].mul - 1) // 2,
+            'u<vw': irrep_in1[ins[0]].mul * (irrep_in2[ins[1]].mul - 1) // 2,
+        }
+        alpha = irrep_mid[ins[2]].ir.dim
+        x = sum([ins_dict[ins[3]] for ins in instructions])
+        if x > 0.0:
+            alpha /= x
+        normalization_coefficients += [math.sqrt(alpha)]
+
+    irrep_mid, p, _ = irrep_mid.sort()
+    instructions = [
+        (i_in1, i_in2, p[i_out], mode, train, alpha)
+        for (i_in1, i_in2, i_out, mode, train), alpha
+        in zip(instructions, normalization_coefficients)
+    ]
+    return irrep_mid, instructions
+
+
+def cutoff_function(x, cutoff):
+    zeros = torch.zeros_like(x)
+    x_ = torch.where(x < cutoff, x, zeros)
+    return torch.where(x < cutoff, torch.exp(-x_**2/((cutoff-x_)*(cutoff+x_))), zeros)
+
+
+class ExponentialBernsteinRadialBasisFunctions(nn.Module):
+    def __init__(self, num_basis_functions, cutoff, ini_alpha=0.5):
+        super(ExponentialBernsteinRadialBasisFunctions, self).__init__()
+        self.num_basis_functions = num_basis_functions
+        self.ini_alpha = ini_alpha
+        # compute values to initialize buffers
+        logfactorial = np.zeros((num_basis_functions))
+        for i in range(2,num_basis_functions):
+            logfactorial[i] = logfactorial[i-1] + np.log(i)
+        v = np.arange(0,num_basis_functions)
+        n = (num_basis_functions-1)-v
+        logbinomial = logfactorial[-1]-logfactorial[v]-logfactorial[n]
+        # register buffers and parameters
+        self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float32))
+        self.register_buffer('logc', torch.tensor(logbinomial, dtype=torch.float32))
+        self.register_buffer('n', torch.tensor(n,
dtype=torch.float32)) + self.register_buffer('v', torch.tensor(v, dtype=torch.float32)) + self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float32))) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha)) + + def forward(self, r): + alpha = F.softplus(self._alpha) + x = - alpha * r + x = self.logc + self.n * x + self.v * torch.log(- torch.expm1(x) ) + rbf = cutoff_function(r, self.cutoff) * torch.exp(x) + return rbf + + +class NormGate(torch.nn.Module): + def __init__(self, irrep): + super(NormGate, self).__init__() + self.irrep = irrep + self.norm = o3.Norm(self.irrep) + + num_mul, num_mul_wo_0 = 0, 0 + for mul, ir in self.irrep: + num_mul += mul + if ir.l != 0: + num_mul_wo_0 += mul + + self.mul = o3.ElementwiseTensorProduct( + self.irrep[1:], o3.Irreps(f"{num_mul_wo_0}x0e")) + self.fc = nn.Sequential( + nn.Linear(num_mul, num_mul), + nn.SiLU(), + nn.Linear(num_mul, num_mul)) + + self.num_mul = num_mul + self.num_mul_wo_0 = num_mul_wo_0 + + def forward(self, x): + norm_x = self.norm(x)[:, self.irrep.slices()[0].stop:] + f0 = torch.cat([x[:, self.irrep.slices()[0]], norm_x], dim=-1) + gates = self.fc(f0) + gated = self.mul(x[:, self.irrep.slices()[0].stop:], gates[:, self.irrep.slices()[0].stop:]) + x = torch.cat([gates[:, self.irrep.slices()[0]], gated], dim=-1) + return x + + +class ConvLayer(torch.nn.Module): + def __init__( + self, + irrep_in_node, + irrep_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + invariant_layers=1, + invariant_neurons=32, + avg_num_neighbors=None, + nonlinear='ssp', + use_norm_gate=True, + edge_wise=False, + ): + super(ConvLayer, self).__init__() + self.avg_num_neighbors = avg_num_neighbors + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.edge_wise = edge_wise + + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_hidden = irrep_hidden \ + if isinstance(irrep_hidden, o3.Irreps) else o3.Irreps(irrep_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_out_node, instruction_node = get_feasible_irrep( + self.irrep_in_node, self.sh_irrep, self.irrep_hidden, tp_mode='uvu') + + self.tp_node = TensorProduct( + self.irrep_in_node, + self.sh_irrep, + self.irrep_tp_out_node, + instruction_node, + shared_weights=False, + internal_weights=False, + ) + + self.fc_node = FullyConnectedNet( + [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node.weight_numel], + self.nonlinear_layer + ) + + num_mul = 0 + for mul, ir in self.irrep_in_node: + num_mul = num_mul + mul + + self.layer_l0 = FullyConnectedNet( + [num_mul + self.irrep_in_node[0][0]] + invariant_layers * [invariant_neurons] + [self.tp_node.weight_numel], + self.nonlinear_layer + ) + + self.linear_out = Linear( + irreps_in=self.irrep_tp_out_node, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.use_norm_gate = use_norm_gate + self.norm_gate = NormGate(self.irrep_in_node) + self.irrep_linear_out, instruction_node = get_feasible_irrep( + self.irrep_in_node, o3.Irreps("0e"), self.irrep_in_node) + self.linear_node = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + 
shared_weights=True, + biases=True + ) + self.linear_node_pre = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.inner_product = InnerProduct(self.irrep_in_node) + + def forward(self, data, x): + edge_dst, edge_src = data.edge_index[0], data.edge_index[1] + + if self.use_norm_gate: + pre_x = self.linear_node_pre(x) + s0 = self.inner_product(pre_x[edge_dst], pre_x[edge_src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([pre_x[edge_dst][:, self.irrep_in_node.slices()[0]], + pre_x[edge_dst][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + x = self.norm_gate(x) + x = self.linear_node(x) + else: + s0 = self.inner_product(x[edge_dst], x[edge_src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([x[edge_dst][:, self.irrep_in_node.slices()[0]], + x[edge_dst][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + + self_x = x + + edge_features = self.tp_node( + x[edge_src], data.edge_sh, self.fc_node(data.edge_attr) * self.layer_l0(s0)) + + if self.edge_wise: + out = edge_features + else: + out = scatter(edge_features, edge_dst, dim=0, dim_size=len(x)) + + if self.irrep_in_node == self.irrep_out: + out = out + self_x + + out = self.linear_out(out) + return out + + +class InnerProduct(torch.nn.Module): + def __init__(self, irrep_in): + super(InnerProduct, self).__init__() + self.irrep_in = o3.Irreps(irrep_in).simplify() + irrep_out = o3.Irreps([(mul, "0e") for mul, _ in self.irrep_in]) + instr = [(i, i, i, "uuu", False, 1/ir.dim) for i, (mul, ir) in enumerate(self.irrep_in)] + self.tp = o3.TensorProduct(self.irrep_in, self.irrep_in, irrep_out, instr, irrep_normalization="component") + self.irrep_out = irrep_out.simplify() + + def forward(self, features_1, features_2): + out = self.tp(features_1, features_2) + return out + + +class ConvNetLayer(torch.nn.Module): + def __init__( + self, + irrep_in_node, + irrep_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + use_norm_gate=True, + edge_wise=False, + ): + super(ConvNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_hidden = irrep_hidden if isinstance(irrep_hidden, o3.Irreps) \ + else o3.Irreps(irrep_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.resnet = resnet and self.irrep_in_node == self.irrep_out + + self.conv = ConvLayer( + irrep_in_node=self.irrep_in_node, + irrep_hidden=self.irrep_hidden, + sh_irrep=self.sh_irrep, + irrep_out=self.irrep_out, + edge_attr_dim=self.edge_attr_dim, + node_attr_dim=self.node_attr_dim, + invariant_layers=1, + invariant_neurons=32, + avg_num_neighbors=None, + nonlinear='ssp', + use_norm_gate=use_norm_gate, + edge_wise=edge_wise + ) + + def forward(self, data, x): + old_x = x + x = self.conv(data, x) + if self.resnet and self.irrep_out == self.irrep_in_node: + x = old_x + x + return x + + +class PairNetLayer(torch.nn.Module): + def __init__(self, + irrep_in_node, + irrep_bottle_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + invariant_layers=1, + invariant_neurons=8, + nonlinear='ssp'): + 
super(PairNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + self.invariant_layers = invariant_layers + self.invariant_neurons = invariant_neurons + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_bottle_hidden = irrep_bottle_hidden \ + if isinstance(irrep_bottle_hidden, o3.Irreps) else o3.Irreps(irrep_bottle_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + self.sh_irrep = sh_irrep if isinstance(sh_irrep, o3.Irreps) else o3.Irreps(sh_irrep) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_in_node, _ = get_feasible_irrep(self.irrep_in_node, o3.Irreps("0e"), self.irrep_bottle_hidden) + self.irrep_tp_out_node_pair, instruction_node_pair = get_feasible_irrep( + self.irrep_tp_in_node, self.irrep_tp_in_node, self.irrep_bottle_hidden, tp_mode='uuu') + + self.irrep_tp_out_node_pair_msg, instruction_node_pair_msg = get_feasible_irrep( + self.irrep_tp_in_node, self.sh_irrep, self.irrep_bottle_hidden, tp_mode='uvu') + + self.linear_node_pair = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_tp_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.linear_node_pair_n = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.linear_node_pair_inner = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.tp_node_pair = TensorProduct( + self.irrep_tp_in_node, + self.irrep_tp_in_node, + self.irrep_tp_out_node_pair, + instruction_node_pair, + shared_weights=False, + internal_weights=False, + ) + + self.irrep_tp_out_node_pair_2, instruction_node_pair_2 = get_feasible_irrep( + self.irrep_tp_out_node_pair, self.irrep_tp_out_node_pair, self.irrep_bottle_hidden, tp_mode='uuu') + + self.tp_node_pair_2 = TensorProduct( + self.irrep_tp_out_node_pair, + self.irrep_tp_out_node_pair, + self.irrep_tp_out_node_pair_2, + instruction_node_pair_2, + shared_weights=True, + internal_weights=True + ) + + + self.fc_node_pair = FullyConnectedNet( + [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node_pair.weight_numel], + self.nonlinear_layer + ) + + self.linear_node_pair_2 = Linear( + irreps_in=self.irrep_tp_out_node_pair_2, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + + if self.irrep_in_node == self.irrep_out and resnet: + self.resnet = True + else: + self.resnet = False + + self.linear_node_pair = Linear( + irreps_in=self.irrep_tp_out_node_pair, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.norm_gate = NormGate(self.irrep_tp_out_node_pair) + self.inner_product = InnerProduct(self.irrep_in_node) + self.norm = o3.Norm(self.irrep_in_node) + num_mul = 0 + for mul, ir in self.irrep_in_node: + num_mul = num_mul + mul + + self.norm_gate_pre = NormGate(self.irrep_tp_out_node_pair) + self.fc = nn.Sequential( + nn.Linear(self.irrep_in_node[0][0] + num_mul, self.irrep_in_node[0][0]), + nn.SiLU(), + nn.Linear(self.irrep_in_node[0][0], self.tp_node_pair.weight_numel)) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, data, node_attr, 
node_pair_attr=None): + dst, src = data.full_edge_index + node_attr_0 = self.linear_node_pair_inner(node_attr) + s0 = self.inner_product(node_attr_0[dst], node_attr_0[src])[:, self.irrep_in_node.slices()[0].stop:] + s0 = torch.cat([node_attr_0[dst][:, self.irrep_in_node.slices()[0]], + node_attr_0[src][:, self.irrep_in_node.slices()[0]], s0], dim=-1) + + node_attr = self.norm_gate_pre(node_attr) + node_attr = self.linear_node_pair_n(node_attr) + + node_pair = self.tp_node_pair(node_attr[src], node_attr[dst], + self.fc_node_pair(data.full_edge_attr) * self.fc(s0)) + + node_pair = self.norm_gate(node_pair) + node_pair = self.linear_node_pair(node_pair) + + if self.resnet and node_pair_attr is not None: + node_pair = node_pair + node_pair_attr + return node_pair + + +class SelfNetLayer(torch.nn.Module): + def __init__(self, + irrep_in_node, + irrep_bottle_hidden, + irrep_out, + sh_irrep, + edge_attr_dim, + node_attr_dim, + resnet: bool = True, + nonlinear='ssp'): + super(SelfNetLayer, self).__init__() + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + self.sh_irrep = sh_irrep + self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) + self.irrep_bottle_hidden = irrep_bottle_hidden \ + if isinstance(irrep_bottle_hidden, o3.Irreps) else o3.Irreps(irrep_bottle_hidden) + self.irrep_out = irrep_out if isinstance(irrep_out, o3.Irreps) else o3.Irreps(irrep_out) + + self.edge_attr_dim = edge_attr_dim + self.node_attr_dim = node_attr_dim + self.resnet = resnet + self.nonlinear_layer = get_nonlinear(nonlinear) + + self.irrep_tp_in_node, _ = get_feasible_irrep(self.irrep_in_node, o3.Irreps("0e"), self.irrep_bottle_hidden) + self.irrep_tp_out_node, instruction_node = get_feasible_irrep( + self.irrep_tp_in_node, self.irrep_tp_in_node, self.irrep_bottle_hidden, tp_mode='uuu') + + # - Build modules - + self.linear_node_1 = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + + self.linear_node_2 = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_in_node, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.tp = TensorProduct( + self.irrep_tp_in_node, + self.irrep_tp_in_node, + self.irrep_tp_out_node, + instruction_node, + shared_weights=True, + internal_weights=True + ) + self.norm_gate = NormGate(self.irrep_out) + self.norm_gate_1 = NormGate(self.irrep_in_node) + self.norm_gate_2 = NormGate(self.irrep_in_node) + self.linear_node_3 = Linear( + irreps_in=self.irrep_tp_out_node, + irreps_out=self.irrep_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + + def forward(self, data, x, old_fii): + old_x = x + xl = self.norm_gate_1(x) + xl = self.linear_node_1(xl) + xr = self.norm_gate_2(x) + xr = self.linear_node_2(xr) + x = self.tp(xl, xr) + if self.resnet: + x = x + old_x + x = self.norm_gate(x) + x = self.linear_node_3(x) + if self.resnet and old_fii is not None: + x = old_fii + x + return x + + @property + def device(self): + return next(self.parameters()).device + + +class Expansion(nn.Module): + def __init__(self, irrep_in, irrep_out_1, irrep_out_2): + super(Expansion, self).__init__() + self.irrep_in = irrep_in + self.irrep_out_1 = irrep_out_1 + self.irrep_out_2 = irrep_out_2 + self.instructions = self.get_expansion_path(irrep_in, irrep_out_1, irrep_out_2) + self.num_path_weight = sum(prod(ins[-1]) for ins in self.instructions if ins[3]) + self.num_bias = 
sum([prod(ins[-1][1:]) for ins in self.instructions if ins[0] == 0]) + if self.num_path_weight > 0: + self.weights = nn.Parameter(torch.rand(self.num_path_weight + self.num_bias)) + self.num_weights = self.num_path_weight + self.num_bias + + def forward(self, x_in, weights=None, bias_weights=None): + batch_num = x_in.shape[0] + if len(self.irrep_in) == 1: + x_in_s = [x_in.reshape(batch_num, self.irrep_in[0].mul, self.irrep_in[0].ir.dim)] + else: + x_in_s = [ + x_in[:, i].reshape(batch_num, mul_ir.mul, mul_ir.ir.dim) + for i, mul_ir in zip(self.irrep_in.slices(), self.irrep_in)] + + outputs = {} + flat_weight_index = 0 + bias_weight_index = 0 + for ins in self.instructions: + mul_ir_in = self.irrep_in[ins[0]] + mul_ir_out1 = self.irrep_out_1[ins[1]] + mul_ir_out2 = self.irrep_out_2[ins[2]] + x1 = x_in_s[ins[0]] + x1 = x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) + w3j_matrix = o3.wigner_3j(ins[1], ins[2], ins[0]).to(self.device).type(x1.type()) + if ins[3] is True or weights is not None: + if weights is None: + weight = self.weights[flat_weight_index:flat_weight_index + prod(ins[-1])].reshape(ins[-1]) + result = torch.einsum( + f"wuv, ijk, bwk-> buivj", weight, w3j_matrix, x1) / mul_ir_in.mul + else: + weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) + result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) + if ins[0] == 0 and bias_weights is not None: + bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ + reshape([-1] + ins[-1][1:]) + bias_weight_index += prod(ins[-1][1:]) + result = result + bias_weight.unsqueeze(-1) + result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul + flat_weight_index += prod(ins[-1]) + else: + result = torch.einsum( + f"uvw, ijk, bwk-> buivj", torch.ones(ins[-1]).type(x1.type()).to(self.device), w3j_matrix, + x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) + ) + result = result.reshape(batch_num, mul_ir_out1.dim, mul_ir_out2.dim) + key = (ins[1], ins[2]) + if key in outputs.keys(): + outputs[key] = outputs[key] + result + else: + outputs[key] = result + + rows = [] + for i in range(len(self.irrep_out_1)): + blocks = [] + for j in range(len(self.irrep_out_2)): + if (i, j) not in outputs.keys(): + blocks += [torch.zeros((x_in.shape[0], self.irrep_out_1[i].dim, self.irrep_out_2[j].dim), + device=x_in.device).type(x_in.type())] + else: + blocks += [outputs[(i, j)]] + rows.append(torch.cat(blocks, dim=-1)) + output = torch.cat(rows, dim=-2) + return output + + def get_expansion_path(self, irrep_in, irrep_out_1, irrep_out_2): + instructions = [] + for i, (num_in, ir_in) in enumerate(irrep_in): + for j, (num_out1, ir_out1) in enumerate(irrep_out_1): + for k, (num_out2, ir_out2) in enumerate(irrep_out_2): + if ir_in in ir_out1 * ir_out2: + instructions.append([i, j, k, True, 1.0, [num_in, num_out1, num_out2]]) + return instructions + + @property + def device(self): + return next(self.parameters()).device + + def __repr__(self): + return f'{self.irrep_in} -> {self.irrep_out_1}x{self.irrep_out_1} and bias {self.num_bias}' \ + f'with parameters {self.num_path_weight}' \ No newline at end of file diff --git a/nablaDFT/qhnet/loss.py b/nablaDFT/qhnet/loss.py new file mode 100644 index 0000000..ada04b4 --- /dev/null +++ b/nablaDFT/qhnet/loss.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MAE_RMSE_Loss(nn.Module): + def __init__(self) -> None: + super(MAE_RMSE_Loss, self).__init__() + + def 
forward(self, pred, target): + mse = F.mse_loss(pred, target, reduction=None) + mae = F.l1_loss(pred, target, reduction="mean") + rmse = torch.sqrt(mse.mean()) + return mae + rmse diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py new file mode 100644 index 0000000..78e039a --- /dev/null +++ b/nablaDFT/qhnet/qhnet.py @@ -0,0 +1,478 @@ +from typing import Dict + +import torch +from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch_geometric.data import Data +from torch_cluster import radius_graph +from e3nn import o3 +from e3nn.o3 import Linear +import pytorch_lightning as pl + +from .layers import ExponentialBernsteinRadialBasisFunctions, ConvNetLayer, PairNetLayer, SelfNetLayer, Expansion, get_nonlinear + + +ATOM_MASKS = { + 1: [0, 1, 5, 6, 7], + 6: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 7: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 8: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 9: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 16: [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 17: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], + 35: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] +} + + +class QHNet(nn.Module): + def __init__(self, + in_node_features=1, + sh_lmax=4, + hidden_size=128, + bottle_hidden_size=32, + num_gnn_layers=5, + max_radius=12, + num_nodes=10, + radius_embed_dim=32): # maximum nuclear charge (+1, i.e. 87 for up to Rn) for embeddings, can be kept at default + super(QHNet, self).__init__() + # store hyperparameter values + self.atom_orbs = [ + [[8, 0, '1s'], [8, 0, '2s'], [8, 0, '3s'], [8, 1, '2p'], [8, 1, '3p'], [8, 2, '3d']], + [[1, 0, '1s'], [1, 0, '2s'], [1, 1, '2p']], + [[1, 0, '1s'], [1, 0, '2s'], [1, 1, '2p']] + ] + self.order = sh_lmax + self.sh_irrep = o3.Irreps.spherical_harmonics(lmax=self.order) + self.hs = hidden_size + self.hbs = bottle_hidden_size + self.radius_embed_dim = radius_embed_dim + self.max_radius = max_radius + self.num_gnn_layers = num_gnn_layers + self.node_embedding = nn.Embedding(num_nodes, self.hs) + self.hidden_irrep = o3.Irreps(f'{self.hs}x0e + {self.hs}x1o + {self.hs}x2e + {self.hs}x3o + {self.hs}x4e') + self.hidden_bottle_irrep = o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1o + {self.hbs}x2e + {self.hbs}x3o + {self.hbs}x4e') + self.hidden_irrep_base = o3.Irreps(f'{self.hs}x0e + {self.hs}x1e + {self.hs}x2e + {self.hs}x3e + {self.hs}x4e') + self.hidden_bottle_irrep_base = o3.Irreps( + f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e') + self.final_out_irrep = o3.Irreps(f'{self.hs * 3}x0e + {self.hs * 2}x1o + {self.hs}x2e').simplify() + self.input_irrep = o3.Irreps(f'{self.hs}x0e') + self.distance_expansion = ExponentialBernsteinRadialBasisFunctions(self.radius_embed_dim, self.max_radius) + self.nonlinear_scalars = {1: "ssp", -1: "tanh"} + self.nonlinear_gates = {1: "ssp", -1: "abs"} + self.num_fc_layer = 1 + + self.e3_gnn_layer = nn.ModuleList() + self.e3_gnn_node_pair_layer = nn.ModuleList() + self.e3_gnn_node_layer = nn.ModuleList() + self.udpate_layer = nn.ModuleList() + self.start_layer = 2 + for i in range(self.num_gnn_layers): + input_irrep = self.input_irrep if i == 0 else self.hidden_irrep + self.e3_gnn_layer.append(ConvNetLayer( + irrep_in_node=input_irrep, + irrep_hidden=self.hidden_irrep, + irrep_out=self.hidden_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + sh_irrep=self.sh_irrep, + 
resnet=True, + use_norm_gate=True if i != 0 else False + )) + + if i > self.start_layer: + self.e3_gnn_node_layer.append(SelfNetLayer( + irrep_in_node=self.hidden_irrep_base, + irrep_bottle_hidden=self.hidden_irrep_base, + irrep_out=self.hidden_irrep_base, + sh_irrep=self.sh_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + resnet=True, + )) + + self.e3_gnn_node_pair_layer.append(PairNetLayer( + irrep_in_node=self.hidden_irrep_base, + irrep_bottle_hidden=self.hidden_irrep_base, + irrep_out=self.hidden_irrep_base, + sh_irrep=self.sh_irrep, + edge_attr_dim=self.radius_embed_dim, + node_attr_dim=self.hs, + invariant_layers=self.num_fc_layer, + invariant_neurons=self.hs, + resnet=True, + )) + + self.nonlinear_layer = get_nonlinear('ssp') + self.expand_ii, self.expand_ij, self.fc_ii, self.fc_ij, self.fc_ii_bias, self.fc_ij_bias = \ + nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict(), nn.ModuleDict() + for name in {"hamiltonian"}: + input_expand_ii = o3.Irreps(f"{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e") + + self.expand_ii[name] = Expansion( + input_expand_ii, + o3.Irreps("3x0e + 2x1e + 1x2e"), + o3.Irreps("3x0e + 2x1e + 1x2e") + ) + self.fc_ii[name] = torch.nn.Sequential( + nn.Linear(self.hs, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ii[name].num_path_weight) + ) + self.fc_ii_bias[name] = torch.nn.Sequential( + nn.Linear(self.hs, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ii[name].num_bias) # TODO: this shit defines output dimension for diagonal block + ) + + self.expand_ij[name] = Expansion( + o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e'), + o3.Irreps("3x0e + 2x1e + 1x2e"), + o3.Irreps("3x0e + 2x1e + 1x2e") + ) + + self.fc_ij[name] = torch.nn.Sequential( + nn.Linear(self.hs * 2, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ij[name].num_path_weight) + ) + + self.fc_ij_bias[name] = torch.nn.Sequential( + nn.Linear(self.hs * 2, self.hs), + nn.SiLU(), + nn.Linear(self.hs, self.expand_ij[name].num_bias) + ) + + self.output_ii = Linear(self.hidden_irrep, self.hidden_bottle_irrep) + self.output_ij = Linear(self.hidden_irrep, self.hidden_bottle_irrep) + self.orbital_mask = {} + + def set(self): + for key in ATOM_MASKS.keys(): + self.orbital_mask[key] = torch.tensor(ATOM_MASKS[key]).to(self.device) + + def get_number_of_parameters(self): + num = 0 + for param in self.parameters(): + if param.requires_grad: + num += param.numel() + return num + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, data, keep_blocks=False): + node_attr, edge_index, rbf_new, edge_sh, _ = self.build_graph(data, self.max_radius) + node_attr = self.node_embedding(node_attr) + data.node_attr, data.edge_index, data.edge_attr, data.edge_sh = \ + node_attr, edge_index, rbf_new, edge_sh + + _, full_edge_index, full_edge_attr, full_edge_sh, transpose_edge_index = \ + self.build_graph(data, max_radius=10000) + + data.full_edge_index, data.full_edge_attr, data.full_edge_sh = \ + full_edge_index, full_edge_attr, full_edge_sh + + full_dst, full_src = data.full_edge_index + + fii = None + fij = None + for layer_idx, layer in enumerate(self.e3_gnn_layer): + node_attr = layer(data, node_attr) + if layer_idx > self.start_layer: + fii = self.e3_gnn_node_layer[layer_idx-self.start_layer-1](data, node_attr, fii) + fij = self.e3_gnn_node_pair_layer[layer_idx-self.start_layer-1](data, node_attr, fij) + + fii = self.output_ii(fii) 
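+        # output_ii / output_ij project the accumulated node (fii) and node-pair (fij)
+        # features from the hidden irreps down to the bottleneck irreps consumed by the
+        # Expansion heads that assemble the Hamiltonian blocks below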
+ fij = self.output_ij(fij) + hamiltonian_diagonal_matrix = self.expand_ii['hamiltonian']( + fii, self.fc_ii['hamiltonian'](data.node_attr), self.fc_ii_bias['hamiltonian'](data.node_attr)) + node_pair_embedding = torch.cat([data.node_attr[full_dst], data.node_attr[full_src]], dim=-1) + hamiltonian_non_diagonal_matrix = self.expand_ij['hamiltonian']( + fij, self.fc_ij['hamiltonian'](node_pair_embedding), + self.fc_ij_bias['hamiltonian'](node_pair_embedding)) + + if keep_blocks is False: + hamiltonian_matrix = self.build_final_matrix( + data, hamiltonian_diagonal_matrix, hamiltonian_non_diagonal_matrix) + hamiltonian_matrix = hamiltonian_matrix + hamiltonian_matrix.transpose(-1, -2) + return hamiltonian_matrix + else: + ret_hamiltonian_diagonal_matrix = hamiltonian_diagonal_matrix +\ + hamiltonian_diagonal_matrix.transpose(-1, -2) + + # the transpose should considers the i, j + ret_hamiltonian_non_diagonal_matrix = hamiltonian_non_diagonal_matrix + \ + hamiltonian_non_diagonal_matrix[transpose_edge_index].transpose(-1, -2) + + results = {} + results['hamiltonian_diagonal_blocks'] = ret_hamiltonian_diagonal_matrix + results['hamiltonian_non_diagonal_blocks'] = ret_hamiltonian_non_diagonal_matrix + return results + + def build_graph(self, data, max_radius, edge_index=None): + node_attr = data.z.squeeze() + + if edge_index is None: + radius_edges = radius_graph(data.pos, max_radius, data.batch, max_num_neighbors=data.num_nodes) + else: + radius_edges = data.full_edge_index + + dst, src = radius_edges + edge_vec = data.pos[dst.long()] - data.pos[src.long()] + rbf = self.distance_expansion(edge_vec.norm(dim=-1).unsqueeze(-1)).squeeze().type(data.pos.type()) + + edge_sh = o3.spherical_harmonics( + self.sh_irrep, edge_vec[:, [1, 2, 0]], + normalize=True, normalization='component').type(data.pos.type()) + + start_edge_index = 0 + all_transpose_index = [] + for graph_idx in range(data.ptr.shape[0] - 1): + num_nodes = data.ptr[graph_idx +1] - data.ptr[graph_idx] + graph_edge_index = radius_edges[:, start_edge_index:start_edge_index+num_nodes*(num_nodes-1)] + sub_graph_edge_index = graph_edge_index - data.ptr[graph_idx] + bias = (sub_graph_edge_index[0] < sub_graph_edge_index[1]).type(torch.int) + transpose_index = sub_graph_edge_index[0] * (num_nodes - 1) + sub_graph_edge_index[1] - bias + transpose_index = transpose_index + start_edge_index + all_transpose_index.append(transpose_index) + start_edge_index = start_edge_index + num_nodes*(num_nodes-1) + + return node_attr, radius_edges, rbf, edge_sh, torch.cat(all_transpose_index, dim=-1) + + def build_final_matrix(self, data, diagonal_matrix, non_diagonal_matrix): + # concate the blocks together and then select once. 
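+        # per molecule in the batch: diagonal (atom) and off-diagonal (atom-pair) blocks
+        # are cropped with the per-element orbital masks, stitched into one matrix per
+        # molecule, and the per-molecule matrices are then combined block-diagonally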
+ final_matrix = [] + dst, src = data.full_edge_index + for graph_idx in range(data.ptr.shape[0] - 1): + matrix_block_col = [] + for src_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): + matrix_col = [] + for dst_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): + if src_idx == dst_idx: + matrix_col.append(diagonal_matrix[src_idx].index_select( + -2, self.orbital_mask[data.z[dst_idx].item()]).index_select( + -1, self.orbital_mask[data.z[src_idx].item()]) + ) + else: + mask1 = (src == src_idx) + mask2 = (dst == dst_idx) + index = torch.where(mask1 & mask2)[0].item() + + matrix_col.append( + non_diagonal_matrix[index].index_select( + -2, self.orbital_mask[data.z[dst_idx].item()]).index_select( + -1, self.orbital_mask[data.z[src_idx].item()])) + matrix_block_col.append(torch.cat(matrix_col, dim=-2)) + final_matrix.append(torch.cat(matrix_block_col, dim=-1)) + final_matrix = torch.block_diag(final_matrix) + return final_matrix + + def split_matrix(self, data): + diagonal_matrix, non_diagonal_matrix = \ + torch.zeros(data.z.shape[0], 14, 14).type(data.pos.type()).to(self.device), \ + torch.zeros(data.edge_index.shape[1], 14, 14).type(data.pos.type()).to(self.device) + + data.matrix = data.matrix.reshape( + len(data.ptr) - 1, data.matrix.shape[-1], data.matrix.shape[-1]) + + num_atoms = 0 + num_edges = 0 + for graph_idx in range(data.ptr.shape[0] - 1): + slices = [0] + for atom_idx in data.z[range(data.ptr[graph_idx], data.ptr[graph_idx + 1])]: + slices.append(slices[-1] + len(self.orbital_mask[atom_idx.item()])) + + for node_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): + node_idx = node_idx - num_atoms + orb_mask = self.orbital_mask[data.z[node_idx].item()] + diagonal_matrix[node_idx][orb_mask][:, orb_mask] = \ + data.matrix[graph_idx][slices[node_idx]: slices[node_idx+1], slices[node_idx]: slices[node_idx+1]] + + for edge_index_idx in range(num_edges, data.edge_index.shape[1]): + dst, src = data.edge_index[:, edge_index_idx] + if dst > data.ptr[graph_idx + 1] or src > data.ptr[graph_idx + 1]: + break + num_edges = num_edges + 1 + orb_mask_dst = self.orbital_mask[data.z[dst].item()] + orb_mask_src = self.orbital_mask[data.z[src].item()] + graph_dst, graph_src = dst - num_atoms, src - num_atoms + non_diagonal_matrix[edge_index_idx][orb_mask_dst][:, orb_mask_src] = \ + data.matrix[graph_idx][slices[graph_dst]: slices[graph_dst+1], slices[graph_src]: slices[graph_src+1]] + + num_atoms = num_atoms + data.ptr[graph_idx + 1] - data.ptr[graph_idx] + return diagonal_matrix, non_diagonal_matrix + + +class QHNetLightning(pl.LightningModule): + def __init__( + self, + model_name: str, + net: nn.Module, + optimizer: Optimizer, + lr_scheduler: LRScheduler, + losses: Dict, + metric, + loss_coefs, + ) -> None: + super(QHNetLightning, self).__init__() + self.net = net + self.save_hyperparameters(logger=True) + + def forward(self, data: Data): + hamiltonian = self.net(data) + return hamiltonian + + def step(self, batch, calculate_metrics: bool = False): + hamiltonian_out = self.net(batch) + hamiltonian = batch.hamiltonian + preds = {'hamiltonian': hamiltonian_out} + hamiltonian = torch.block_diag(*[torch.from_numpy(H) for H in hamiltonian]) + target = {'hamiltonian': hamiltonian} + loss = self._calculate_loss(preds, target) + if calculate_metrics: + metrics = self._calculate_metrics(preds, target) + return loss, metrics + return loss + + def training_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + loss = self.step(batch, calculate_metrics=False) + 
self._log_current_lr() + self.log( + "train/loss", + loss, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def validation_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + # with self.ema.average_parameters(): + loss, metrics = self.step(batch, calculate_metrics=True) + self.log( + "val/loss", + loss, + prog_bar=True, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + # workaround for checkpoint callback + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + logger=False, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def test_step(self, batch, batch_idx): + bsz = self._get_batch_size(batch) + with self.ema.average_parameters(): + loss, metrics = self.step(batch, calculate_metrics=True) + self.log( + "test/loss", + loss, + prog_bar=True, + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + batch_size=bsz, + ) + return loss + + def predict_step(self, data): + hamiltonian = self(data) + return hamiltonian + + def configure_optimizers(self): + optimizer = self.hparams.optimizer(params=self.parameters()) + if self.hparams.lr_scheduler is not None: + scheduler = self.hparams.lr_scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + "monitor": "val_loss", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + +# def on_before_zero_grad(self, optimizer: Optimizer) -> None: +# self.ema.update() + + def on_fit_start(self) -> None: + # self._instantiate_ema() + self._check_devices() + + def on_test_start(self) -> None: + # self._instantiate_ema() + self._check_devices() + + def on_validation_epoch_end(self) -> None: + self._reduce_metrics(step_type="val") + + def on_test_epoch_end(self) -> None: + self._reduce_metrics(step_type="test") + + def _calculate_loss(self, y_pred, y_true) -> float: + # Note: since hamiltonians has different shapes, loss calculated per sample + total_loss = 0.0 + for name, loss in self.hparams.losses.items(): + total_loss += self.hparams.loss_coefs[name] * loss( + y_pred[name], y_true[name] + ) + return total_loss + + def _calculate_metrics(self, y_pred, y_true) -> Dict: + """Function for metrics calculation during step.""" + metric = self.hparams.metric(y_pred, y_true) + return metric + + def _log_current_lr(self) -> None: + opt = self.optimizers() + current_lr = opt.optimizer.param_groups[0]["lr"] + self.log("LR", current_lr, logger=True) + + def _reduce_metrics(self, step_type: str = "train"): + metric = self.hparams.metric.compute() + for key in metric.keys(): + self.log( + f"{step_type}/{key}", + metric[key], + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + self.hparams.metric.reset() + + def _check_devices(self): + self.hparams.metric = self.hparams.metric.to(self.device) + self.net.set() +# if self.ema is not None: +# self.ema.to(self.device) + +# def _instantiate_ema(self): +# if self.ema is not None: +# self.ema = self.ema(self.parameters()) + + def _get_batch_size(self, batch): + """Function for batch size infer.""" + bsz = batch.batch.max().detach().item() + 1 # get batch size + return bsz \ No newline at end of file From a8885e20c22995996535ba55f9a2645378789ddb Mon Sep 17 00:00:00 2001 From: BerAnton Date: Wed, 13 Mar 2024 20:27:11 +0300 Subject: [PATCH 08/57] move schedulers, introduce new hyperparam for qhnet etc. 
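
The hard-coded ATOM_MASKS table is replaced by an `orbitals` model hyperparameter:
for every atomic number it lists the angular momentum of each basis shell
(0 = s, 1 = p, 2 = d), and QHNet._get_mask() derives the orbital index masks from it.
The graphormer-specific scheduler module is removed here; a shared
nablaDFT/schedulers.py is introduced in the next commit. A minimal sketch of what one
`orbitals` entry encodes (values taken from config/model/qhnet.yaml below; the helper
name is illustrative only):

    # number of basis functions implied by one `orbitals` entry
    def shells_to_size(shells):                      # e.g. H -> [0, 0, 1]
        return sum(2 * l + 1 for l in shells)

    assert shells_to_size([0, 0, 1]) == 5            # H: 2 x s + 1 x p shells
    assert shells_to_size([0, 0, 0, 1, 1, 2]) == 14  # C/N/O/F: 3 x s + 2 x p + 1 x d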
--- config/model/qhnet.yaml | 24 +++-- config/qhnet.yaml | 2 +- nablaDFT/__init__.py | 2 + nablaDFT/graphormer/__init__.py | 1 - nablaDFT/graphormer/schedulers.py | 45 ---------- nablaDFT/qhnet/loss.py | 2 +- nablaDFT/qhnet/qhnet.py | 144 +++++++++++++----------------- 7 files changed, 85 insertions(+), 135 deletions(-) delete mode 100644 nablaDFT/graphormer/schedulers.py diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml index aa2e282..4d2baa7 100644 --- a/config/model/qhnet.yaml +++ b/config/model/qhnet.yaml @@ -3,6 +3,7 @@ _target_: nablaDFT.qhnet.QHNetLightning model_name: "QHNet" net: _target_: nablaDFT.qhnet.QHNet + _convert_: partial sh_lmax: 4 hidden_size: 128 bottle_hidden_size: 32 @@ -10,20 +11,28 @@ net: max_radius: 12 num_nodes: 83 radius_embed_dim: 32 + orbitals: + 1: [0, 0, 1] + 6: [0, 0, 0, 1, 1, 2] + 7: [0, 0, 0, 1, 1, 2] + 8: [0, 0, 0, 1, 1, 2] + 9: [0, 0, 0, 1, 1, 2] + 16: [0, 0, 0, 0, 1, 1, 1, 2] + 17: [0, 0, 0, 0, 1, 1, 1, 2] + 35: [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2] optimizer: _target_: torch.optim.AdamW _partial_: true amsgrad: true betas: [0.9, 0.95] - lr: 1e-3 - weight_decay: 0 + lr: 5e-4 lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _target_: nablaDFT.schedulers.get_polynomial_decay_schedule_with_warmup _partial_: true - factor: 0.8 - patience: 10 + num_warmup_steps: ${warmup_steps} + num_training_steps: ${max_steps} losses: hamiltonian: @@ -37,3 +46,8 @@ metric: task_metrics: hamiltonian: _target_: torchmetrics.MeanAbsoluteError + +ema: + _target_: torch_ema.ExponentialMovingAverage + _partial_: true + decay: 0.9999 \ No newline at end of file diff --git a/config/qhnet.yaml b/config/qhnet.yaml index 495be9b..63a996f 100644 --- a/config/qhnet.yaml +++ b/config/qhnet.yaml @@ -2,7 +2,7 @@ name: QHNet dataset_name: dataset_train_2k max_steps: 1000000 -warmup_steps: 0 +warmup_steps: 10000 job_type: train pretrained: False ckpt_path: null # path to checkpoint for training resume or test run diff --git a/nablaDFT/__init__.py b/nablaDFT/__init__.py index bc469f0..1dae3d5 100644 --- a/nablaDFT/__init__.py +++ b/nablaDFT/__init__.py @@ -7,3 +7,5 @@ from . import escn from . import ase_model from . import painn_pyg + +from .schedulers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup diff --git a/nablaDFT/graphormer/__init__.py b/nablaDFT/graphormer/__init__.py index 751acc5..1e95b58 100644 --- a/nablaDFT/graphormer/__init__.py +++ b/nablaDFT/graphormer/__init__.py @@ -1,2 +1 @@ from .graphormer_3d import Graphormer3DLightning, Graphormer3D -from .schedulers import get_linear_schedule_with_warmup diff --git a/nablaDFT/graphormer/schedulers.py b/nablaDFT/graphormer/schedulers.py deleted file mode 100644 index d941dc8..0000000 --- a/nablaDFT/graphormer/schedulers.py +++ /dev/null @@ -1,45 +0,0 @@ -from functools import partial -from torch.optim.lr_scheduler import LambdaLR - - -def get_linear_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. 
- num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. - """ - - lr_lambda = partial( - _get_linear_schedule_with_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - ) - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def _get_linear_schedule_with_warmup_lr_lambda( - current_step: int, *, num_warmup_steps: int, num_training_steps: int -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, - float(num_training_steps - current_step) - / float(max(1, num_training_steps - num_warmup_steps)), - ) diff --git a/nablaDFT/qhnet/loss.py b/nablaDFT/qhnet/loss.py index ada04b4..e2e6422 100644 --- a/nablaDFT/qhnet/loss.py +++ b/nablaDFT/qhnet/loss.py @@ -8,7 +8,7 @@ def __init__(self) -> None: super(MAE_RMSE_Loss, self).__init__() def forward(self, pred, target): - mse = F.mse_loss(pred, target, reduction=None) + mse = F.mse_loss(pred, target, reduction="none") mae = F.l1_loss(pred, target, reduction="mean") rmse = torch.sqrt(mse.mean()) return mae + rmse diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index 78e039a..d2d443f 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -1,5 +1,6 @@ from typing import Dict +import numpy as np import torch from torch import nn from torch.optim import Optimizer @@ -13,20 +14,11 @@ from .layers import ExponentialBernsteinRadialBasisFunctions, ConvNetLayer, PairNetLayer, SelfNetLayer, Expansion, get_nonlinear -ATOM_MASKS = { - 1: [0, 1, 5, 6, 7], - 6: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 7: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 8: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 9: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 16: [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 17: [0, 1, 2, 5, 6, 7, 8, 9, 10, 17, 18, 19, 20, 21], - 35: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] -} - - class QHNet(nn.Module): + """Modified QHNet from paper + Args: + orbitals (Dict): defines orbitals for each atom type from the dataset. + """ def __init__(self, in_node_features=1, sh_lmax=4, @@ -35,14 +27,10 @@ def __init__(self, num_gnn_layers=5, max_radius=12, num_nodes=10, - radius_embed_dim=32): # maximum nuclear charge (+1, i.e. 87 for up to Rn) for embeddings, can be kept at default + radius_embed_dim=32, # maximum nuclear charge (+1, i.e. 
87 for up to Rn) for embeddings, can be kept at default + orbitals: Dict = None): super(QHNet, self).__init__() # store hyperparameter values - self.atom_orbs = [ - [[8, 0, '1s'], [8, 0, '2s'], [8, 0, '3s'], [8, 1, '2p'], [8, 1, '3p'], [8, 2, '3d']], - [[1, 0, '1s'], [1, 0, '2s'], [1, 1, '2p']], - [[1, 0, '1s'], [1, 0, '2s'], [1, 1, '2p']] - ] self.order = sh_lmax self.sh_irrep = o3.Irreps.spherical_harmonics(lmax=self.order) self.hs = hidden_size @@ -53,16 +41,14 @@ def __init__(self, self.node_embedding = nn.Embedding(num_nodes, self.hs) self.hidden_irrep = o3.Irreps(f'{self.hs}x0e + {self.hs}x1o + {self.hs}x2e + {self.hs}x3o + {self.hs}x4e') self.hidden_bottle_irrep = o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1o + {self.hbs}x2e + {self.hbs}x3o + {self.hbs}x4e') - self.hidden_irrep_base = o3.Irreps(f'{self.hs}x0e + {self.hs}x1e + {self.hs}x2e + {self.hs}x3e + {self.hs}x4e') - self.hidden_bottle_irrep_base = o3.Irreps( - f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e') - self.final_out_irrep = o3.Irreps(f'{self.hs * 3}x0e + {self.hs * 2}x1o + {self.hs}x2e').simplify() + self.hidden_irrep_base = o3.Irreps(f'{self.hs}x0e + {self.hs}x1e + {self.hs}x2e + {self.hs}x3e + {self.hs}x4e') # in use self.input_irrep = o3.Irreps(f'{self.hs}x0e') self.distance_expansion = ExponentialBernsteinRadialBasisFunctions(self.radius_embed_dim, self.max_radius) - self.nonlinear_scalars = {1: "ssp", -1: "tanh"} - self.nonlinear_gates = {1: "ssp", -1: "abs"} self.num_fc_layer = 1 + orbital_mask, max_s, max_p, max_d = self._get_mask(orbitals) # max_* used below to define output representation + self.orbital_mask = orbital_mask + self.e3_gnn_layer = nn.ModuleList() self.e3_gnn_node_pair_layer = nn.ModuleList() self.e3_gnn_node_layer = nn.ModuleList() @@ -112,8 +98,8 @@ def __init__(self, self.expand_ii[name] = Expansion( input_expand_ii, - o3.Irreps("3x0e + 2x1e + 1x2e"), - o3.Irreps("3x0e + 2x1e + 1x2e") + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e"), # here we define which basis we use + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e") # here we define which basis we use ) self.fc_ii[name] = torch.nn.Sequential( nn.Linear(self.hs, self.hs), @@ -125,11 +111,10 @@ def __init__(self, nn.SiLU(), nn.Linear(self.hs, self.expand_ii[name].num_bias) # TODO: this shit defines output dimension for diagonal block ) - self.expand_ij[name] = Expansion( o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e'), - o3.Irreps("3x0e + 2x1e + 1x2e"), - o3.Irreps("3x0e + 2x1e + 1x2e") + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e"), # here we define which basis we use + o3.Irreps(f"{max_s}x0e + {max_p}x1e + {max_d}x2e") # here we define which basis we use ) self.fc_ij[name] = torch.nn.Sequential( @@ -146,11 +131,10 @@ def __init__(self, self.output_ii = Linear(self.hidden_irrep, self.hidden_bottle_irrep) self.output_ij = Linear(self.hidden_irrep, self.hidden_bottle_irrep) - self.orbital_mask = {} def set(self): - for key in ATOM_MASKS.keys(): - self.orbital_mask[key] = torch.tensor(ATOM_MASKS[key]).to(self.device) + for key in self.orbital_mask.keys(): + self.orbital_mask[key] = self.orbital_mask[key].to(self.device) def get_number_of_parameters(self): num = 0 @@ -193,7 +177,6 @@ def forward(self, data, keep_blocks=False): hamiltonian_non_diagonal_matrix = self.expand_ij['hamiltonian']( fij, self.fc_ij['hamiltonian'](node_pair_embedding), self.fc_ij_bias['hamiltonian'](node_pair_embedding)) - if keep_blocks is False: hamiltonian_matrix = 
self.build_final_matrix( data, hamiltonian_diagonal_matrix, hamiltonian_non_diagonal_matrix) @@ -267,43 +250,29 @@ def build_final_matrix(self, data, diagonal_matrix, non_diagonal_matrix): -1, self.orbital_mask[data.z[src_idx].item()])) matrix_block_col.append(torch.cat(matrix_col, dim=-2)) final_matrix.append(torch.cat(matrix_block_col, dim=-1)) - final_matrix = torch.block_diag(final_matrix) + final_matrix = torch.block_diag(*final_matrix) return final_matrix - - def split_matrix(self, data): - diagonal_matrix, non_diagonal_matrix = \ - torch.zeros(data.z.shape[0], 14, 14).type(data.pos.type()).to(self.device), \ - torch.zeros(data.edge_index.shape[1], 14, 14).type(data.pos.type()).to(self.device) - - data.matrix = data.matrix.reshape( - len(data.ptr) - 1, data.matrix.shape[-1], data.matrix.shape[-1]) - - num_atoms = 0 - num_edges = 0 - for graph_idx in range(data.ptr.shape[0] - 1): - slices = [0] - for atom_idx in data.z[range(data.ptr[graph_idx], data.ptr[graph_idx + 1])]: - slices.append(slices[-1] + len(self.orbital_mask[atom_idx.item()])) - - for node_idx in range(data.ptr[graph_idx], data.ptr[graph_idx+1]): - node_idx = node_idx - num_atoms - orb_mask = self.orbital_mask[data.z[node_idx].item()] - diagonal_matrix[node_idx][orb_mask][:, orb_mask] = \ - data.matrix[graph_idx][slices[node_idx]: slices[node_idx+1], slices[node_idx]: slices[node_idx+1]] - - for edge_index_idx in range(num_edges, data.edge_index.shape[1]): - dst, src = data.edge_index[:, edge_index_idx] - if dst > data.ptr[graph_idx + 1] or src > data.ptr[graph_idx + 1]: - break - num_edges = num_edges + 1 - orb_mask_dst = self.orbital_mask[data.z[dst].item()] - orb_mask_src = self.orbital_mask[data.z[src].item()] - graph_dst, graph_src = dst - num_atoms, src - num_atoms - non_diagonal_matrix[edge_index_idx][orb_mask_dst][:, orb_mask_src] = \ - data.matrix[graph_idx][slices[graph_dst]: slices[graph_dst+1], slices[graph_src]: slices[graph_src+1]] - - num_atoms = num_atoms + data.ptr[graph_idx + 1] - data.ptr[graph_idx] - return diagonal_matrix, non_diagonal_matrix + + def _get_mask(self, orbitals): + # get orbitals by z + # retrieve max orbital and get ranges for mask + max_z = max(orbitals.keys()) + _, counts = np.unique(orbitals[max_z], return_counts=True) + s_max, p_max, d_max = counts # max orbital number per type + s_range = [i for i in range(s_max)] + p_range = [i + max(s_range) + 1 for i in range(p_max * 3)] + d_range = [i + max(p_range) + 1 for i in range(d_max * 5)] + ranges = [s_range, p_range, d_range] + orbs_count = [1, 3, 5] # orbital count per type + # create mask for each atom type + atom_orb_masks = {} + for atom in orbitals.keys(): + _, orb_count = np.unique(orbitals[atom], return_counts=True) + mask = [] + for idx, val in enumerate(orb_count): + mask.extend(ranges[idx][:orb_count[idx] * orbs_count[idx]]) + atom_orb_masks[atom] = torch.tensor(mask) + return atom_orb_masks, s_max, p_max, d_max class QHNetLightning(pl.LightningModule): @@ -314,11 +283,13 @@ def __init__( optimizer: Optimizer, lr_scheduler: LRScheduler, losses: Dict, + ema, metric, loss_coefs, ) -> None: super(QHNetLightning, self).__init__() self.net = net + self.ema = ema self.save_hyperparameters(logger=True) def forward(self, data: Data): @@ -354,8 +325,8 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - # with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + with self.ema.average_parameters(): + loss, metrics = 
self.step(batch, calculate_metrics=True) self.log( "val/loss", loss, @@ -413,15 +384,15 @@ def configure_optimizers(self): } return {"optimizer": optimizer} -# def on_before_zero_grad(self, optimizer: Optimizer) -> None: -# self.ema.update() + def on_before_zero_grad(self, optimizer: Optimizer) -> None: + self.ema.update() def on_fit_start(self) -> None: - # self._instantiate_ema() + self._instantiate_ema() self._check_devices() def on_test_start(self) -> None: - # self._instantiate_ema() + self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -465,14 +436,23 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) self.net.set() -# if self.ema is not None: -# self.ema.to(self.device) + if self.ema is not None: + self.ema.to(self.device) -# def _instantiate_ema(self): -# if self.ema is not None: -# self.ema = self.ema(self.parameters()) + def _instantiate_ema(self): + if self.ema is not None: + self.ema = self.ema(self.parameters()) def _get_batch_size(self, batch): """Function for batch size infer.""" bsz = batch.batch.max().detach().item() + 1 # get batch size - return bsz \ No newline at end of file + return bsz + + def _get_hamiltonian_sizes(self, batch): + sizes = [] + for idx in range(batch.ptr.shape[0] - 1): + atoms = batch.z[batch.ptr[idx]: batch.ptr[idx + 1]] + size = sum([self.net.orbital_mask[atom] for atom in atoms]) + sizes.append(size) + return sizes + \ No newline at end of file From 910ddab1f7e900e5fd56d5cbf97242348cf088c9 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Wed, 13 Mar 2024 20:27:30 +0300 Subject: [PATCH 09/57] move scheduler --- nablaDFT/schedulers.py | 112 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 nablaDFT/schedulers.py diff --git a/nablaDFT/schedulers.py b/nablaDFT/schedulers.py new file mode 100644 index 0000000..48e364b --- /dev/null +++ b/nablaDFT/schedulers.py @@ -0,0 +1,112 @@ +from functools import partial +from torch.optim.lr_scheduler import LambdaLR + + +def get_linear_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
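+
+    Example (illustrative step counts; assumes an optimizer has already been created):
+
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=1_000, num_training_steps=100_000
+        )
+        # call scheduler.step() after every optimizer.step()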
+ """ + + lr_lambda = partial( + _get_linear_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def _get_linear_schedule_with_warmup_lr_lambda( + current_step: int, *, num_warmup_steps: int, num_training_steps: int +): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps)), + ) + + +def _get_polynomial_decay_schedule_with_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float, + power: float, + lr_init: int, +): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + +def get_polynomial_decay_schedule_with_warmup( + optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 +): + """ + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT + implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ + """ + + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") + + lr_lambda = partial( + _get_polynomial_decay_schedule_with_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_end=lr_end, + power=power, + lr_init=lr_init, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) \ No newline at end of file From 0bdc4c71b44a1609b6d4f9d2c1059a8a1a78f6eb Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 14 Mar 2024 15:32:28 +0300 Subject: [PATCH 10/57] qhnet update --- config/datamodule/nablaDFT_hamiltonian.yaml | 2 +- config/model/qhnet.yaml | 2 +- nablaDFT/dataset/nablaDFT_dataset.py | 2 +- nablaDFT/qhnet/loss.py | 17 ++++++++------- nablaDFT/qhnet/qhnet.py | 23 ++++++++++++++------- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/config/datamodule/nablaDFT_hamiltonian.yaml b/config/datamodule/nablaDFT_hamiltonian.yaml index 989b77f..3f55569 100644 --- a/config/datamodule/nablaDFT_hamiltonian.yaml +++ b/config/datamodule/nablaDFT_hamiltonian.yaml @@ -7,7 +7,7 @@ train_size: 0.9 val_size: 0.1 seed: 23 # Dataloader args -batch_size: 2 +batch_size: 8 num_workers: 4 persistent_workers: True pin_memory: True diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml index 4d2baa7..7b0bab3 100644 --- a/config/model/qhnet.yaml +++ b/config/model/qhnet.yaml @@ -36,7 +36,7 @@ lr_scheduler: losses: hamiltonian: - _target_: nablaDFT.qhnet.loss.MAE_RMSE_Loss + _target_: nablaDFT.qhnet.loss.HamiltonianLoss loss_coefs: hamiltonian: 1.0 diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 2b710e2..18e7682 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -172,7 +172,7 @@ def __init__( def setup(self, stage: str) -> None: if stage == "fit": - dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs)[:50] # TODO: temporal subset + dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed)) elif stage == "test": diff --git a/nablaDFT/qhnet/loss.py b/nablaDFT/qhnet/loss.py index e2e6422..d4d87a5 100644 --- a/nablaDFT/qhnet/loss.py +++ b/nablaDFT/qhnet/loss.py @@ -3,12 +3,15 @@ import torch.nn.functional as F -class MAE_RMSE_Loss(nn.Module): +class HamiltonianLoss(nn.Module): def __init__(self) -> None: - super(MAE_RMSE_Loss, self).__init__() + super(HamiltonianLoss, self).__init__() - def forward(self, pred, target): - mse = F.mse_loss(pred, target, reduction="none") - mae = F.l1_loss(pred, target, reduction="mean") - rmse = torch.sqrt(mse.mean()) - return mae + rmse + def forward(self, pred, target, mask): + diff = pred - target + mse = torch.mean(diff**2) + mae = torch.mean(torch.abs(diff)) + mse *= (pred.numel() / mask.sum()) + mae *= (pred.numel() / mask.sum()) + rmse = torch.sqrt(mse) + return rmse + mae diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index d2d443f..0e5a1e2 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -290,7 +290,7 @@ def __init__( super(QHNetLightning, self).__init__() self.net = net self.ema = ema - self.save_hyperparameters(logger=True) + self.save_hyperparameters(logger=True, ignore=['net']) def forward(self, data: Data): hamiltonian = self.net(data) @@ -300,11 +300,12 @@ def step(self, batch, calculate_metrics: 
bool = False): hamiltonian_out = self.net(batch) hamiltonian = batch.hamiltonian preds = {'hamiltonian': hamiltonian_out} - hamiltonian = torch.block_diag(*[torch.from_numpy(H) for H in hamiltonian]) + masks = torch.block_diag(*[torch.ones_like(torch.from_numpy(H)) for H in hamiltonian]) + hamiltonian = torch.block_diag(*[torch.from_numpy(H) for H in hamiltonian]).to(self.device) target = {'hamiltonian': hamiltonian} - loss = self._calculate_loss(preds, target) + loss = self._calculate_loss(preds, target, masks) if calculate_metrics: - metrics = self._calculate_metrics(preds, target) + metrics = self._calculate_metrics(preds, target, masks) return loss, metrics return loss @@ -401,18 +402,26 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") - def _calculate_loss(self, y_pred, y_true) -> float: +# def on_after_backward(self): +# for name, param in self.named_parameters(): +# if param.grad is None: +# print(name) + + def _calculate_loss(self, y_pred, y_true, masks) -> float: # Note: since hamiltonians has different shapes, loss calculated per sample total_loss = 0.0 for name, loss in self.hparams.losses.items(): total_loss += self.hparams.loss_coefs[name] * loss( - y_pred[name], y_true[name] + y_pred[name], y_true[name], masks ) return total_loss - def _calculate_metrics(self, y_pred, y_true) -> Dict: + def _calculate_metrics(self, y_pred, y_true, mask) -> Dict: """Function for metrics calculation during step.""" + # TODO: temp workaround for metric normalization by mask sum + norm_coef = (y_pred['hamiltonian'].numel() / mask.sum()) metric = self.hparams.metric(y_pred, y_true) + metric['hamiltonian'] = metric['hamiltonian'] * norm_coef return metric def _log_current_lr(self) -> None: From 0c572f78613eae4f799ea03f779f47c9a124fda4 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 13:12:22 +0300 Subject: [PATCH 11/57] change scheduler+ change scheduler for qhnet --- config/model/qhnet.yaml | 7 ++++--- nablaDFT/schedulers.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml index 7b0bab3..f630b1f 100644 --- a/config/model/qhnet.yaml +++ b/config/model/qhnet.yaml @@ -29,10 +29,11 @@ optimizer: lr: 5e-4 lr_scheduler: - _target_: nablaDFT.schedulers.get_polynomial_decay_schedule_with_warmup + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true - num_warmup_steps: ${warmup_steps} - num_training_steps: ${max_steps} + factor: 0.8 + patience: 10 + min_lr: 1e-6 losses: hamiltonian: diff --git a/nablaDFT/schedulers.py b/nablaDFT/schedulers.py index 48e364b..e1ad7e9 100644 --- a/nablaDFT/schedulers.py +++ b/nablaDFT/schedulers.py @@ -54,6 +54,7 @@ def _get_polynomial_decay_schedule_with_warmup_lr_lambda( power: float, lr_init: int, ): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L218 if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step > num_training_steps: @@ -69,6 +70,7 @@ def _get_polynomial_decay_schedule_with_warmup_lr_lambda( def get_polynomial_decay_schedule_with_warmup( optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 ): + # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L239 """ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by *lr_end*, after 
a warmup period during which it increases linearly from 0 to the From 304508f59f492c65dee665bbcf35707f25e2b481 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 13:58:10 +0300 Subject: [PATCH 12/57] modified setup.py: precise versions for crucial dependencies --- setup.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/setup.py b/setup.py index 5089771..8f933c2 100644 --- a/setup.py +++ b/setup.py @@ -18,27 +18,28 @@ def read(fname): include_package_data=True, python_requires=">=3.6", install_requires=[ - "numpy", - "sympy", - "ase>=3.21", - "h5py", - "apsw", - "schnetpack>=2.0.0", + "numpy>=1.26", + "sympy==1.12", + "ase==3.22.1", + "h5py==3.10.0", + "apsw==3.45.1.0", + "schnetpack==2.0.4", "tensorboardX", "pyyaml", - "hydra-core>=1.1.0", - "pytorch_lightning>=1.9.0", - "torch-geometric>=2.3.1", - "torchmetrics", + "hydra-core==1.2.0", + "torch==2.2.0", + "pytorch_lightning==2.1.4", + "torch-geometric==2.4.0", + "torchmetrics==1.0.1", "hydra-colorlog>=1.1.0", "rich", "fasteners", "dirsync", - "torch-ema", + "torch-ema==0.3", "matscipy", "python-dotenv", - "wandb", - "e3nn" + "wandb==0.16.3", + "e3nn==0.5.1" ], license="MIT", description="nablaDFT: Large-Scale Conformational Energy and Hamiltonian Prediction benchmark and dataset", From 0e4f8ec6b2c3bd34e1148562929b0d0070a1455f Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 14:01:25 +0300 Subject: [PATCH 13/57] add dependencies for PyG --- setup.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/setup.py b/setup.py index 8f933c2..6818d7a 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,10 @@ from setuptools import setup, find_packages + +CUDA = "cu121" + + def read(fname): with io.open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as f: return f.read() @@ -28,6 +32,9 @@ def read(fname): "pyyaml", "hydra-core==1.2.0", "torch==2.2.0", + "torch-scatter @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", + "torch-sparse @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", + "torch-cluster @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", "pytorch_lightning==2.1.4", "torch-geometric==2.4.0", "torchmetrics==1.0.1", From 26ef09687c8c246f4a838e94d6093d8a1a7b89b3 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 14:09:49 +0300 Subject: [PATCH 14/57] remove old links, updated names --- nablaDFT/links/energy_databases.json | 24 +++++++++++++----------- nablaDFT/links/energy_databases_v2.json | 18 ------------------ 2 files changed, 13 insertions(+), 29 deletions(-) delete mode 100644 nablaDFT/links/energy_databases_v2.json diff --git a/nablaDFT/links/energy_databases.json b/nablaDFT/links/energy_databases.json index f3791c5..c94874c 100644 --- a/nablaDFT/links/energy_databases.json +++ b/nablaDFT/links/energy_databases.json @@ -1,16 +1,18 @@ { "train_databases": { - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_100k_energy.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_10k_energy.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_5k_energy.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_2k_energy.db" + "dataset_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_full_v2_formation_energy.db", + "dataset_train_100k": 
"https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train100k_v2_formation_energy_w_forces_wo_outliers.db", + "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train10k_v2_formation_energy_w_forces.db", + "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train5k_v2_formation_energy_w_forces.db", + "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train2k_v2_formation_energy_w_forces.db" }, "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_random.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_scaffolds.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers.db.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers.db" + "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_structures_v2_formation_energy.db", + "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_scaffolds_v2_formation_energy.db", + "dataset_test_conformations_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_conformers_v2_formation_energy.db", + "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers_v2_formation_energy_w_forces.db" } -} \ No newline at end of file +} diff --git a/nablaDFT/links/energy_databases_v2.json b/nablaDFT/links/energy_databases_v2.json deleted file mode 100644 index c94874c..0000000 --- a/nablaDFT/links/energy_databases_v2.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "train_databases": { - "dataset_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_full_v2_formation_energy.db", - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train100k_v2_formation_energy_w_forces_wo_outliers.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train10k_v2_formation_energy_w_forces.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train5k_v2_formation_energy_w_forces.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train2k_v2_formation_energy_w_forces.db" - }, - "test_databases": { - "dataset_test_structures": 
"https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_structures_v2_formation_energy.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_scaffolds_v2_formation_energy.db", - "dataset_test_conformations_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_conformers_v2_formation_energy.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers_v2_formation_energy_w_forces.db" - } -} From 47fb67f550b0fc116dda3f9a12e7e755f7d6d844 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 16:14:43 +0300 Subject: [PATCH 15/57] datasets links updated --- nablaDFT/links/energy_databases.json | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/nablaDFT/links/energy_databases.json b/nablaDFT/links/energy_databases.json index c94874c..6657d5f 100644 --- a/nablaDFT/links/energy_databases.json +++ b/nablaDFT/links/energy_databases.json @@ -1,18 +1,17 @@ { "train_databases": { - "dataset_full": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train_full_v2_formation_energy.db", - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train100k_v2_formation_energy_w_forces_wo_outliers.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train10k_v2_formation_energy_w_forces.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train5k_v2_formation_energy_w_forces.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/train2k_v2_formation_energy_w_forces.db" + "dataset_train_full": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_full_v2_formation_energy.db", + "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_100k_v2_formation_energy_w_forces.db", + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_10k_v2_formation_energy_w_forces.db", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_5k_v2_formation_energy_w_forces.db", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_2k_v2_formation_energy_w_forces.db" }, "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_structures_v2_formation_energy.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_scaffolds_v2_formation_energy.db", - "dataset_test_conformations_full": 
"https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_full_conformers_v2_formation_energy.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_100k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_10k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_5k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/energy_dbs/test_2k_conformers_v2_formation_energy_w_forces.db" + "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_structures_v2_formation_energy_forces.db", + "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_scaffolds_v2_formation_energy_forces.db", + "dataset_test_conformations_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_100k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_10k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_5k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_2k_conformers_v2_formation_energy_w_forces.db" } } From 46f1a42d3bd1b780e0641694bdc240d0ad300f88 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 16:31:57 +0300 Subject: [PATCH 16/57] fix setup --- setup.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 6818d7a..642030f 100644 --- a/setup.py +++ b/setup.py @@ -32,9 +32,9 @@ def read(fname): "pyyaml", "hydra-core==1.2.0", "torch==2.2.0", - "torch-scatter @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", - "torch-sparse @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", - "torch-cluster @ https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html", + "torch-scatter", + "torch-sparse ", + "torch-cluster", "pytorch_lightning==2.1.4", "torch-geometric==2.4.0", "torchmetrics==1.0.1", @@ -48,6 +48,11 @@ def read(fname): "wandb==0.16.3", "e3nn==0.5.1" ], + dependency_links=[ + f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html", + f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html", + f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html" + ] license="MIT", description="nablaDFT: Large-Scale Conformational Energy and Hamiltonian Prediction benchmark and dataset", long_description=""" From d294eb896c9b5397a4901c7441e7afd8c70dd643 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 15 Mar 2024 16:44:43 +0300 Subject: [PATCH 17/57] fix setup --- setup.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 642030f..baaceb4 100644 --- a/setup.py +++ b/setup.py @@ -32,9 +32,9 @@ def read(fname): "pyyaml", "hydra-core==1.2.0", "torch==2.2.0", - "torch-scatter", - "torch-sparse ", - "torch-cluster", + "torch-scatter @ 
https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", + "torch-sparse @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_sparse-0.6.18%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", + "torch-cluster @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_cluster-1.6.3%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", "pytorch_lightning==2.1.4", "torch-geometric==2.4.0", "torchmetrics==1.0.1", @@ -48,11 +48,6 @@ def read(fname): "wandb==0.16.3", "e3nn==0.5.1" ], - dependency_links=[ - f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html", - f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html", - f"https://data.pyg.org/whl/torch-2.2.0+{CUDA}.html" - ] license="MIT", description="nablaDFT: Large-Scale Conformational Energy and Hamiltonian Prediction benchmark and dataset", long_description=""" From 58d8965762e75a27d7bc39f75b8e8a803f8b5c93 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 12:11:13 +0300 Subject: [PATCH 18/57] fix graphormer3d cfg --- config/model/graphormer3d-half.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/model/graphormer3d-half.yaml b/config/model/graphormer3d-half.yaml index dbbebaa..a440b66 100644 --- a/config/model/graphormer3d-half.yaml +++ b/config/model/graphormer3d-half.yaml @@ -20,7 +20,7 @@ optimizer: lr: 3e-4 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} From 900a5e7b4c1a5a1d00f106f53e14196a3347dce7 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 12:21:13 +0300 Subject: [PATCH 19/57] make model download more verbose --- nablaDFT/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py index 4aa579e..297b624 100644 --- a/nablaDFT/utils.py +++ b/nablaDFT/utils.py @@ -100,7 +100,8 @@ def download_model(config: DictConfig) -> str: with open(nablaDFT.__path__[0] + "/links/models_checkpoints.json", "r") as f: data = json.load(f) url = data[f"{model_name}"]["dataset_train_100k"] - with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=f"Downloading {model_name} checkpoint") as t: + file_size = get_file_size(url) + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading {model_name} checkpoint") as t: request.urlretrieve(url, ckpt_path, reporthook=tqdm_download_hook(t)) logging.info(f"Downloaded {model_name} 100k checkpoint to {ckpt_path}") return ckpt_path From a870b7d8727563c9394420026b1a9ad48653645e Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 12:37:34 +0300 Subject: [PATCH 20/57] rename equiformer model name --- config/equiformer_v2_oc20.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/equiformer_v2_oc20.yaml b/config/equiformer_v2_oc20.yaml index d9712bb..547daa9 100644 --- a/config/equiformer_v2_oc20.yaml +++ b/config/equiformer_v2_oc20.yaml @@ -1,5 +1,5 @@ # Global variables -name: Equiformer_v2_OC20 +name: Equiformer_v2 dataset_name: dataset_train_2k max_steps: 1000000 warmup_steps: 0 From 66894d960892dc0f3328d1f24579549b488399b5 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 12:38:51 +0300 Subject: [PATCH 21/57] remove temp workaround for dataset w/o forces --- nablaDFT/dataset/pyg_datasets.py | 7 ++----- nablaDFT/dimenetplusplus/dimenetplusplus.py | 8 +------- 
nablaDFT/equiformer_v2/equiformer_v2_oc20.py | 6 +----- nablaDFT/escn/escn.py | 6 +----- nablaDFT/gemnet_oc/gemnet_oc.py | 5 +---- nablaDFT/graphormer/graphormer_3d.py | 16 +++++----------- nablaDFT/painn_pyg/painn.py | 6 +----- 7 files changed, 12 insertions(+), 42 deletions(-) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index e7d2c9f..457d560 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -70,7 +70,7 @@ def get(self, idx): return super(PyGNablaDFT, self).get(idx - self.offsets[data_idx]) def download(self) -> None: - with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json", "r") as f: + with open(nablaDFT.__path__[0] + "/links/energy_databases.json", "r") as f: data = json.load(f) url = data[f"{self.split}_databases"][self.dataset_name] file_size = get_file_size(url) @@ -84,10 +84,7 @@ def process(self) -> None: z = torch.from_numpy(db_row.numbers).long() positions = torch.from_numpy(db_row.positions).float() y = torch.from_numpy(np.array(db_row.data["energy"])).float() - # TODO: temp workaround for dataset w/o forces - forces = db_row.data.get("forces", None) - if forces is not None: - forces = torch.from_numpy(np.array(forces)).float() + forces = torch.from_numpy(np.array(db_row.data["forces"])).float() samples.append(Data(z=z, pos=positions, y=y, forces=forces)) if self.pre_filter is not None: diff --git a/nablaDFT/dimenetplusplus/dimenetplusplus.py b/nablaDFT/dimenetplusplus/dimenetplusplus.py index b6dd06a..6b60464 100644 --- a/nablaDFT/dimenetplusplus/dimenetplusplus.py +++ b/nablaDFT/dimenetplusplus/dimenetplusplus.py @@ -144,13 +144,7 @@ def step( ) -> Union[Tuple[Any, Dict], Any]: predictions_energy, predictions_forces = self.forward(batch) loss_energy = self.loss(predictions_energy, batch.y) - # TODO: temp workaround - if hasattr(batch, "forces"): - loss_forces = self.loss(predictions_forces, batch.forces) - else: - loss_forces = torch.zeros(1).to(self.device) - predictions_forces = torch.zeros(1).to(self.device) - forces = torch.zeros(1).to(self.device) + loss_forces = self.loss(predictions_forces, batch.forces) loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy if calculate_metrics: preds = {"energy": predictions_energy, "forces": predictions_forces} diff --git a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py index 1de12e2..070885a 100644 --- a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py +++ b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py @@ -698,11 +698,7 @@ def step(self, batch, calculate_metrics: bool = False): y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss = self._calculate_loss(preds, target) diff --git a/nablaDFT/escn/escn.py b/nablaDFT/escn/escn.py index 72af98a..cded69b 100644 --- a/nablaDFT/escn/escn.py +++ b/nablaDFT/escn/escn.py @@ -1074,11 +1074,7 @@ def step(self, batch, calculate_metrics: bool = False): y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss 
= self._calculate_loss(preds, target) diff --git a/nablaDFT/gemnet_oc/gemnet_oc.py b/nablaDFT/gemnet_oc/gemnet_oc.py index aaf050f..e8bf524 100644 --- a/nablaDFT/gemnet_oc/gemnet_oc.py +++ b/nablaDFT/gemnet_oc/gemnet_oc.py @@ -1432,10 +1432,7 @@ def forward(self, data: Data): def step(self, batch, calculate_metrics: bool = False): energy_out, forces_out = self.net(batch) # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": batch.y, "forces": forces} loss = self._calculate_loss(preds, target) diff --git a/nablaDFT/graphormer/graphormer_3d.py b/nablaDFT/graphormer/graphormer_3d.py index 1e970c7..8a7fb2d 100644 --- a/nablaDFT/graphormer/graphormer_3d.py +++ b/nablaDFT/graphormer/graphormer_3d.py @@ -380,17 +380,11 @@ def step( y = batch.y energy_out, forces_out, mask_out = self(batch) loss_energy = self.loss(energy_out, y) - # TODO: temp workaround for datasets w/o forces - if hasattr(batch, "forces"): - forces, mask_forces = to_dense_batch( - batch.forces, batch.batch, batch_size=bsz - ) - masked_forces_out = forces_out * mask_forces.unsqueeze(-1) - loss_forces = self.loss(masked_forces_out, forces) - else: - loss_forces = torch.zeros(1).to(self.device) - masked_forces_out = torch.zeros(1).to(self.device) - forces = torch.zeros(1).to(self.device) + forces, mask_forces = to_dense_batch( + batch.forces, batch.batch, batch_size=bsz + ) + masked_forces_out = forces_out * mask_forces.unsqueeze(-1) + loss_forces = self.loss(masked_forces_out, forces) loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy if calculate_metrics: preds = {"energy": energy_out, "forces": masked_forces_out} diff --git a/nablaDFT/painn_pyg/painn.py b/nablaDFT/painn_pyg/painn.py index 2650ddf..1ef8bd7 100644 --- a/nablaDFT/painn_pyg/painn.py +++ b/nablaDFT/painn_pyg/painn.py @@ -722,11 +722,7 @@ def step( y = batch.y # make dense batch from PyG batch energy_out, forces_out = self.net(batch) - # TODO: temp workaround - if hasattr(batch, "forces"): - forces = batch.forces - else: - forces = forces_out.clone() + forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": y, "forces": forces} loss = self._calculate_loss(preds, target) From 571de42f7c0e35fd7a199307fde8201c1c0ae4b6 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 13:49:55 +0300 Subject: [PATCH 22/57] updated models checkpoints links --- nablaDFT/links/models_checkpoints.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 628dc3f..20906f2 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -7,16 +7,16 @@ }, "PaiNN": { "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_100k.ckpt", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_10k.ckpt", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_5k_split.pt", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/models_checkpoints/painn/painn_2k_split.pt" + "dataset_train_10k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_10k.ckpt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_5k.ckpt", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_2k.ckpt" }, "DimeNet++": { "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_100k_epoch=0258.ckpt", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_10k_epoch=0651.ckpt", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_5k_epoch=0545.ckpt", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_2k_epoch=0442.ckpt" + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_10k_epoch=0651.ckpt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_5k_epoch=0442.ckpt", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_2k_epoch=0545.ckpt" }, "PhiSNet": { From 57256b6eecbf7806a2c9cd5e8b94c3a96fd3004e Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 17:12:35 +0300 Subject: [PATCH 23/57] qhnet fixes --- nablaDFT/qhnet/layers.py | 75 ++++++++++++++-------------------------- nablaDFT/qhnet/qhnet.py | 8 +---- 2 files changed, 27 insertions(+), 56 deletions(-) diff --git a/nablaDFT/qhnet/layers.py b/nablaDFT/qhnet/layers.py index 16f53f8..ae8fdf4 100644 --- a/nablaDFT/qhnet/layers.py +++ b/nablaDFT/qhnet/layers.py @@ -172,7 +172,7 @@ def __init__( self.edge_attr_dim = edge_attr_dim self.node_attr_dim = node_attr_dim self.edge_wise = edge_wise - + self.use_norm_gate = use_norm_gate self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) self.irrep_hidden = irrep_hidden \ if isinstance(irrep_hidden, o3.Irreps) else o3.Irreps(irrep_hidden) @@ -213,25 +213,25 @@ def __init__( shared_weights=True, biases=True ) - - self.use_norm_gate = use_norm_gate - self.norm_gate = NormGate(self.irrep_in_node) self.irrep_linear_out, instruction_node = get_feasible_irrep( self.irrep_in_node, o3.Irreps("0e"), self.irrep_in_node) - self.linear_node = Linear( - irreps_in=self.irrep_in_node, - irreps_out=self.irrep_linear_out, - internal_weights=True, - shared_weights=True, - biases=True - ) - self.linear_node_pre = Linear( - irreps_in=self.irrep_in_node, - irreps_out=self.irrep_linear_out, - internal_weights=True, - shared_weights=True, - biases=True - ) + if use_norm_gate: + # if this first layer, then it doesn't need this + self.norm_gate = NormGate(self.irrep_in_node) + self.linear_node = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.linear_node_pre = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + 
shared_weights=True, + biases=True + ) self.inner_product = InnerProduct(self.irrep_in_node) def forward(self, data, x): @@ -399,29 +399,11 @@ def __init__(self, self.irrep_tp_out_node_pair_2, instruction_node_pair_2 = get_feasible_irrep( self.irrep_tp_out_node_pair, self.irrep_tp_out_node_pair, self.irrep_bottle_hidden, tp_mode='uuu') - self.tp_node_pair_2 = TensorProduct( - self.irrep_tp_out_node_pair, - self.irrep_tp_out_node_pair, - self.irrep_tp_out_node_pair_2, - instruction_node_pair_2, - shared_weights=True, - internal_weights=True - ) - - self.fc_node_pair = FullyConnectedNet( [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node_pair.weight_numel], self.nonlinear_layer ) - self.linear_node_pair_2 = Linear( - irreps_in=self.irrep_tp_out_node_pair_2, - irreps_out=self.irrep_out, - internal_weights=True, - shared_weights=True, - biases=True - ) - if self.irrep_in_node == self.irrep_out and resnet: self.resnet = True else: @@ -588,19 +570,14 @@ def forward(self, x_in, weights=None, bias_weights=None): x1 = x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) w3j_matrix = o3.wigner_3j(ins[1], ins[2], ins[0]).to(self.device).type(x1.type()) if ins[3] is True or weights is not None: - if weights is None: - weight = self.weights[flat_weight_index:flat_weight_index + prod(ins[-1])].reshape(ins[-1]) - result = torch.einsum( - f"wuv, ijk, bwk-> buivj", weight, w3j_matrix, x1) / mul_ir_in.mul - else: - weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) - result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) - if ins[0] == 0 and bias_weights is not None: - bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ - reshape([-1] + ins[-1][1:]) - bias_weight_index += prod(ins[-1][1:]) - result = result + bias_weight.unsqueeze(-1) - result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul + weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) + result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) + if ins[0] == 0 and bias_weights is not None: + bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ + reshape([-1] + ins[-1][1:]) + bias_weight_index += prod(ins[-1][1:]) + result = result + bias_weight.unsqueeze(-1) + result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul flat_weight_index += prod(ins[-1]) else: result = torch.einsum( diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index 0e5a1e2..014acb7 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -109,7 +109,7 @@ def __init__(self, self.fc_ii_bias[name] = torch.nn.Sequential( nn.Linear(self.hs, self.hs), nn.SiLU(), - nn.Linear(self.hs, self.expand_ii[name].num_bias) # TODO: this shit defines output dimension for diagonal block + nn.Linear(self.hs, self.expand_ii[name].num_bias) ) self.expand_ij[name] = Expansion( o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e'), @@ -402,13 +402,7 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") -# def on_after_backward(self): -# for name, param in self.named_parameters(): -# if param.grad is None: -# print(name) - def _calculate_loss(self, y_pred, y_true, masks) -> float: - # Note: since hamiltonians has different shapes, loss calculated per sample total_loss = 0.0 for name, loss in self.hparams.losses.items(): total_loss += 
self.hparams.loss_coefs[name] * loss( From c757a9d05e85eb448be345cb06de3efe16526276 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 17:18:03 +0300 Subject: [PATCH 24/57] qhnet fix --- nablaDFT/qhnet/qhnet.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index 0e5a1e2..014acb7 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -109,7 +109,7 @@ def __init__(self, self.fc_ii_bias[name] = torch.nn.Sequential( nn.Linear(self.hs, self.hs), nn.SiLU(), - nn.Linear(self.hs, self.expand_ii[name].num_bias) # TODO: this shit defines output dimension for diagonal block + nn.Linear(self.hs, self.expand_ii[name].num_bias) ) self.expand_ij[name] = Expansion( o3.Irreps(f'{self.hbs}x0e + {self.hbs}x1e + {self.hbs}x2e + {self.hbs}x3e + {self.hbs}x4e'), @@ -402,13 +402,7 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") -# def on_after_backward(self): -# for name, param in self.named_parameters(): -# if param.grad is None: -# print(name) - def _calculate_loss(self, y_pred, y_true, masks) -> float: - # Note: since hamiltonians has different shapes, loss calculated per sample total_loss = 0.0 for name, loss in self.hparams.losses.items(): total_loss += self.hparams.loss_coefs[name] * loss( From 3ca20772b52ebd6b60235d1394a92cb383545b4e Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 18 Mar 2024 17:19:10 +0300 Subject: [PATCH 25/57] qhnet layers fix --- nablaDFT/qhnet/layers.py | 75 ++++++++++++++-------------------------- 1 file changed, 26 insertions(+), 49 deletions(-) diff --git a/nablaDFT/qhnet/layers.py b/nablaDFT/qhnet/layers.py index 16f53f8..ae8fdf4 100644 --- a/nablaDFT/qhnet/layers.py +++ b/nablaDFT/qhnet/layers.py @@ -172,7 +172,7 @@ def __init__( self.edge_attr_dim = edge_attr_dim self.node_attr_dim = node_attr_dim self.edge_wise = edge_wise - + self.use_norm_gate = use_norm_gate self.irrep_in_node = irrep_in_node if isinstance(irrep_in_node, o3.Irreps) else o3.Irreps(irrep_in_node) self.irrep_hidden = irrep_hidden \ if isinstance(irrep_hidden, o3.Irreps) else o3.Irreps(irrep_hidden) @@ -213,25 +213,25 @@ def __init__( shared_weights=True, biases=True ) - - self.use_norm_gate = use_norm_gate - self.norm_gate = NormGate(self.irrep_in_node) self.irrep_linear_out, instruction_node = get_feasible_irrep( self.irrep_in_node, o3.Irreps("0e"), self.irrep_in_node) - self.linear_node = Linear( - irreps_in=self.irrep_in_node, - irreps_out=self.irrep_linear_out, - internal_weights=True, - shared_weights=True, - biases=True - ) - self.linear_node_pre = Linear( - irreps_in=self.irrep_in_node, - irreps_out=self.irrep_linear_out, - internal_weights=True, - shared_weights=True, - biases=True - ) + if use_norm_gate: + # if this first layer, then it doesn't need this + self.norm_gate = NormGate(self.irrep_in_node) + self.linear_node = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) + self.linear_node_pre = Linear( + irreps_in=self.irrep_in_node, + irreps_out=self.irrep_linear_out, + internal_weights=True, + shared_weights=True, + biases=True + ) self.inner_product = InnerProduct(self.irrep_in_node) def forward(self, data, x): @@ -399,29 +399,11 @@ def __init__(self, self.irrep_tp_out_node_pair_2, instruction_node_pair_2 = get_feasible_irrep( self.irrep_tp_out_node_pair, self.irrep_tp_out_node_pair, self.irrep_bottle_hidden, tp_mode='uuu') - 
self.tp_node_pair_2 = TensorProduct( - self.irrep_tp_out_node_pair, - self.irrep_tp_out_node_pair, - self.irrep_tp_out_node_pair_2, - instruction_node_pair_2, - shared_weights=True, - internal_weights=True - ) - - self.fc_node_pair = FullyConnectedNet( [self.edge_attr_dim] + invariant_layers * [invariant_neurons] + [self.tp_node_pair.weight_numel], self.nonlinear_layer ) - self.linear_node_pair_2 = Linear( - irreps_in=self.irrep_tp_out_node_pair_2, - irreps_out=self.irrep_out, - internal_weights=True, - shared_weights=True, - biases=True - ) - if self.irrep_in_node == self.irrep_out and resnet: self.resnet = True else: @@ -588,19 +570,14 @@ def forward(self, x_in, weights=None, bias_weights=None): x1 = x1.reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim) w3j_matrix = o3.wigner_3j(ins[1], ins[2], ins[0]).to(self.device).type(x1.type()) if ins[3] is True or weights is not None: - if weights is None: - weight = self.weights[flat_weight_index:flat_weight_index + prod(ins[-1])].reshape(ins[-1]) - result = torch.einsum( - f"wuv, ijk, bwk-> buivj", weight, w3j_matrix, x1) / mul_ir_in.mul - else: - weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) - result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) - if ins[0] == 0 and bias_weights is not None: - bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ - reshape([-1] + ins[-1][1:]) - bias_weight_index += prod(ins[-1][1:]) - result = result + bias_weight.unsqueeze(-1) - result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul + weight = weights[:, flat_weight_index:flat_weight_index + prod(ins[-1])].reshape([-1] + ins[-1]) + result = torch.einsum(f"bwuv, bwk-> buvk", weight, x1) + if ins[0] == 0 and bias_weights is not None: + bias_weight = bias_weights[:,bias_weight_index:bias_weight_index + prod(ins[-1][1:])].\ + reshape([-1] + ins[-1][1:]) + bias_weight_index += prod(ins[-1][1:]) + result = result + bias_weight.unsqueeze(-1) + result = torch.einsum(f"ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul flat_weight_index += prod(ins[-1]) else: result = torch.einsum( From b6f5b45a815ea059c41fee6853b8a3101b80ea6a Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 19 Mar 2024 15:09:39 +0300 Subject: [PATCH 26/57] make hamiltonian dataset reading from disk --- config/trainer/train.yaml | 3 +- nablaDFT/dataset/pyg_datasets.py | 86 ++++++++++++-------------------- 2 files changed, 33 insertions(+), 56 deletions(-) diff --git a/config/trainer/train.yaml b/config/trainer/train.yaml index 093f8f4..2dffd85 100644 --- a/config/trainer/train.yaml +++ b/config/trainer/train.yaml @@ -5,7 +5,8 @@ accelerator: "gpu" devices: [0] strategy: _target_: pytorch_lightning.strategies.ddp.DDPStrategy - + # QHNet has unused params, uncomment line for train + # find_unused_parameters: True max_steps: ${max_steps} # example of additional arguments for trainer diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index 457d560..d7e2131 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -9,7 +9,7 @@ import numpy as np import torch from ase.db import connect -from torch_geometric.data import InMemoryDataset, Data +from torch_geometric.data import InMemoryDataset, Data, Dataset import nablaDFT from .hamiltonian_dataset import HamiltonianDatabase @@ -98,8 +98,7 @@ def process(self) -> None: logger.info(f"Saved processed dataset: {self.processed_paths[0]}") -# TODO: move this to OnDiskDataset -class 
PyGHamiltonianNablaDFT(InMemoryDataset): +class PyGHamiltonianNablaDFT(Dataset): """Pytorch Geometric dataset for NablaDFT Hamiltonian database. Args: @@ -111,7 +110,7 @@ class PyGHamiltonianNablaDFT(InMemoryDataset): - include_core (bool): if True, retrieves core Hamiltonian matrices from database. - dtype (torch.dtype): defines torch.dtype for energy, positions, forces tensors. - transform (Callable): callable data transform, called on every access to element. - - pre_transform (Callable): callable data transform, called during process() for every element. + - pre_transform (Callable): callable data transform, called on every access to element. Note: Hamiltonian matrix for each molecule has different shape. PyTorch Geometric tries to concatenate each torch.Tensor in batch, so in order to make batch from data we leave all hamiltonian matrices @@ -152,26 +151,37 @@ def __init__( super(PyGHamiltonianNablaDFT, self).__init__(datapath, transform, pre_transform) self.max_orbitals = self._get_max_orbitals(datapath, dataset_name) - for path in self.processed_paths: - data, slices = torch.load(path) - self.data_all.append(data) - self.slices_all.append(slices) - self.offsets.append( - len(slices[list(slices.keys())[0]]) - 1 + self.offsets[-1] - ) + self.db = HamiltonianDatabase(self.raw_paths[0]) def len(self) -> int: - return sum( - len(slices[list(slices.keys())[0]]) - 1 for slices in self.slices_all - ) + return len(self.db) def get(self, idx): - data_idx = 0 - while data_idx < len(self.data_all) - 1 and idx >= self.offsets[data_idx + 1]: - data_idx += 1 - self.data = self.data_all[data_idx] - self.slices = self.slices_all[data_idx] - return super(PyGHamiltonianNablaDFT, self).get(idx - self.offsets[data_idx]) + data = self.db[idx] + z = torch.tensor(data[0]).long() + positions = torch.tensor(data[1]).to(self.dtype) + # see notes + hamiltonian = data[4] + if self.include_overlap: + overlap = data[5] + else: + overlap = None + if self.include_core: + core = data[6] + else: + core = None + y = torch.from_numpy(data[2]).to(self.dtype) + forces = torch.from_numpy(data[3]).to(self.dtype) + data = Data( + z=z, pos=positions, + y=y, forces=forces, + hamiltonian=hamiltonian, + overlap=overlap, + core=core, + ) + if self.pre_transform is not None: + data = self.pre_transform(data) + return data def download(self) -> None: with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: @@ -182,41 +192,7 @@ def download(self) -> None: request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) def process(self) -> None: - database = HamiltonianDatabase(self.raw_paths[0]) - samples = [] - for idx in tqdm(range(len(database)), total=len(database)): - data = database[idx] - z = torch.tensor(data[0]).long() - positions = torch.tensor(data[1]).to(self.dtype) - # see notes - hamiltonian = data[4] - if self.include_overlap: - overlap = data[5] - else: - overlap = None - if self.include_core: - core = data[6] - else: - core = None - y = torch.from_numpy(data[2]).to(self.dtype) - forces = torch.from_numpy(data[3]).to(self.dtype) - samples.append(Data( - z=z, pos=positions, - y=y, forces=forces, - hamiltonian=hamiltonian, - overlap=overlap, - core=core, - )) - - if self.pre_filter is not None: - samples = [data for data in samples if self.pre_filter(data)] - - if self.pre_transform is not None: - samples = [self.pre_transform(data) for data in samples] - - data, slices = self.collate(samples) - torch.save((data, slices), self.processed_paths[0]) - logger.info(f"Saved processed 
dataset: {self.processed_paths[0]}") + pass def _get_max_orbitals(self, datapath, dataset_name): db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) From 0c7a3441956c16105322d49290a5d057f124f18f Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 19 Mar 2024 15:49:16 +0300 Subject: [PATCH 27/57] fix dimenet links --- nablaDFT/links/models_checkpoints.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 20906f2..93c1de9 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -13,10 +13,10 @@ }, "DimeNet++": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_100k_epoch=0258.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_10k_epoch=0651.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_5k_epoch=0442.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet++/DimeNet++_dataset_train_2k_epoch=0545.ckpt" + "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_100k_epoch=0258.ckpt", + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_10k_epoch=0651.ckpt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_5k_epoch=0442.ckpt", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_2k_epoch=0545.ckpt" }, "PhiSNet": { From 0764edb511336ee6f5debbfb5eebeda710166239 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 19 Mar 2024 16:16:01 +0300 Subject: [PATCH 28/57] add graphormer links, rename graphormer model for consistency --- config/graphormer3d.yaml | 2 +- nablaDFT/links/models_checkpoints.json | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/config/graphormer3d.yaml b/config/graphormer3d.yaml index 05f6d70..a7b4d42 100644 --- a/config/graphormer3d.yaml +++ b/config/graphormer3d.yaml @@ -1,5 +1,5 @@ # Global variables -name: Graphormer3D-half +name: Graphormer3D-small dataset_name: dataset_train_2k max_steps: 1000000 warmup_steps: 60000 diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 93c1de9..890bf72 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -47,7 +47,13 @@ "Equiformer_v2": { "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_100k_epoch=010_val_loss=0.302613.ckpt", "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_10k_epoch=107_val_loss=0.337899.ckpt", - "dataset_train_5k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckptt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckpt", "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_2k_epoch=409_val_loss=0.354093.ckpt" + }, + "Graphormer3D-small": { + "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_100k_epoch_420_val_loss_0.005773.ckpt", + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_10k_epoch_1095_val_loss_0.010159.ckpt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_5k_epoch_1331_val_loss_0.012179.ckpt", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" } } \ No newline at end of file From 8496d7ec4ecfb805d74c61530bb366a1e6b7badb Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 21 Mar 2024 11:16:21 +0300 Subject: [PATCH 29/57] remove delta stuff --- generate_delta_dataset.py | 97 --------------------------------------- 1 file changed, 97 deletions(-) delete mode 100644 generate_delta_dataset.py diff --git a/generate_delta_dataset.py b/generate_delta_dataset.py deleted file mode 100644 index fea48cb..0000000 --- a/generate_delta_dataset.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -import argparse -from pathlib import Path - -from tqdm import tqdm -from ase.db import connect -from xtb.libxtb import VERBOSITY_MUTED -from xtb.interface import Calculator, Param - - -ATOMNUM_TO_ELEM = { - 1: "H", - 6: "C", - 7: "N", - 8: "O", - 9: "F", - 15: "P", - 16: "S", - 17: "Cl", - 35: "Br", - 53: "I", -} -ATOM_ENERGIES_XTB = { - "H": -0.393482763936, - "C": -1.793296371365, - "O": -3.767606950376, - "N": -2.605824161279, - "F": -4.619339964238, - "S": -3.146456870402, - "P": -2.374178794732, - "Cl": -4.482525134961, - "Br": -4.048339371234, - "I": -3.779630263390, -} -CONV_FACTOR = 0.52917720859 - - -logger = logging.getLogger(__name__) - - -def atomic_energy(atoms): - atom_symbol = [ATOMNUM_TO_ELEM[atom_num] for atom_num in atoms] - atomic_energy = [ATOM_ENERGIES_XTB[atom] for atom in atom_symbol] - return sum(atomic_energy) - -def calculate_gfn2(atoms, positions): - calc = Calculator(Param.GFN2xTB, atoms, positions / CONV_FACTOR) - calc.set_accuracy(0.0001) - calc.set_max_iterations(100) - calc.set_verbosity(VERBOSITY_MUTED) - res = calc.singlepoint() - energy = res.get_energy() - atomic_energy(atoms) - forces = res.get_gradient() * CONV_FACTOR - return energy, forces - -def generate_gfn2xtb_db(input_db_path, output_db_path): - db = connect(input_db_path) - with connect(output_db_path) as odb: - for row in tqdm(db.select(), desc="Generate GFN2-xTB database", total=len(db)): - gfn2_energy, gfn2_force = calculate_gfn2(row.numbers, row.positions) - data = {"energy": gfn2_energy, "forces": gfn2_force} - odb.write(row, data=data) - - -def generate_delta_db( - 
dft_db_path: str, - gfn_db_path: str, - output_db_path: str - ): - dft_db = connect(dft_db_path) - gfn_db = connect(gfn_db_path) - with connect(output_db_path) as odb: - for idx in tqdm(range(1, len(dft_db) + 1), desc="Generate Delta database", total=(len(dft_db)+1)): - row = dft_db.get(idx) # used for new db - dft_data = row.data - gfn_data = gfn_db.get(idx).data - data = { - "energy": [dft_data["energy"][0] - gfn_data["energy"]], - "forces": dft_data["forces"] - gfn_data["forces"] - } - odb.write(row, data=data) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--input_db", type=str, help="path to ASE database with atoms and positions" - ) - args, unknown = parser.parse_known_args() - suffix = Path(args.input_db).suffix - input_path_wo_suffix = args.input_db[:-len(suffix)] - gfn2xtb_db_path = input_path_wo_suffix + "_gfn2xtb.db" - delta_db_path = input_path_wo_suffix + "_delta.db" - generate_gfn2xtb_db(args.input_db, gfn2xtb_db_path) - logger.info(f"Generate GFN2-xTB database at {gfn2xtb_db_path}") - generate_delta_db(args.input_db, gfn2xtb_db_path, delta_db_path) - logger.info(f"Generate Delta database at {delta_db_path}") From 0d341cef45871d361117e460f6df519144a29735 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 21 Mar 2024 11:18:00 +0300 Subject: [PATCH 30/57] qhnet cfg changes --- config/qhnet.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/qhnet.yaml b/config/qhnet.yaml index 63a996f..495be9b 100644 --- a/config/qhnet.yaml +++ b/config/qhnet.yaml @@ -2,7 +2,7 @@ name: QHNet dataset_name: dataset_train_2k max_steps: 1000000 -warmup_steps: 10000 +warmup_steps: 0 job_type: train pretrained: False ckpt_path: null # path to checkpoint for training resume or test run From 9ce57d35a8531b86eac8e705f05f803d10acc474 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 22 Mar 2024 11:54:52 +0300 Subject: [PATCH 31/57] updated hamiltonian links --- nablaDFT/dataset/__init__.py | 2 +- nablaDFT/links/hamiltonian_databases.json | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nablaDFT/dataset/__init__.py b/nablaDFT/dataset/__init__.py index bde9da1..f055433 100644 --- a/nablaDFT/dataset/__init__.py +++ b/nablaDFT/dataset/__init__.py @@ -1,3 +1,3 @@ from .nablaDFT_dataset import ASENablaDFT, PyGNablaDFTDataModule, PyGHamiltonianDataModule -from .hamiltonian_dataset import HamiltonianDataset # database interface for Hamiltonian datasets +from .hamiltonian_dataset import HamiltonianDataset, HamiltonianDatabase # database interface for Hamiltonian datasets from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT # PyTorch Geometric interfaces for datasets \ No newline at end of file diff --git a/nablaDFT/links/hamiltonian_databases.json b/nablaDFT/links/hamiltonian_databases.json index bc6849a..3c5fc90 100644 --- a/nablaDFT/links/hamiltonian_databases.json +++ b/nablaDFT/links/hamiltonian_databases.json @@ -1,16 +1,16 @@ { "train_databases": { - "dataset_train_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_100k.db", - "dataset_train_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_10k.db", - "dataset_train_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_5k.db", - "dataset_train_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_train_2k.db" + "dataset_train_100k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_100k.db", + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_10k.db", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_5k.db", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_2k.db" }, "test_databases": { - "dataset_test_structures": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_random.db", - "dataset_test_scaffolds": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_scaffolds.db", - "dataset_test_conformations_100k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_100k_conformers.db", - "dataset_test_conformations_10k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_10k_conformers.db", - "dataset_test_conformations_5k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_5k_conformers.db", - "dataset_test_conformations_2k": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_test_2k_conformers.db" + "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_structures.db", + "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_scaffolds.db", + "dataset_test_conformations_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_100k_conformers.db", + "dataset_test_conformations_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_10k_conformers.db", + "dataset_test_conformations_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_5k_conformers.db", + "dataset_test_conformations_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_2k_conformers.db" } } \ No newline at end of file From 34c5f2067737b3ad6ba8fc6b511a5d4dee6b1008 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 22 Mar 2024 13:26:00 +0300 Subject: [PATCH 32/57] fix links for ase db --- nablaDFT/dataset/nablaDFT_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 18e7682..35ffdd8 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -60,7 +60,7 @@ def prepare_data(self): if self.split == "predict" and not exists: raise FileNotFoundError("Specified dataset not found") elif self.split != "predict" and not exists: - with open(nablaDFT.__path__[0] + "/links/energy_databases_v2.json") as f: + with open(nablaDFT.__path__[0] + "/links/energy_databases.json") as f: data = json.load(f) if self.train_ratio != 0: url = data["train_databases"][self.dataset_name] From 2b43b4ab3d9be25b6ed43fe73badd44083d4c152 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Fri, 22 Mar 2024 16:31:30 +0300 Subject: [PATCH 33/57] qhnet fix predict step --- nablaDFT/qhnet/qhnet.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) 
diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index 014acb7..a0d0e05 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List import numpy as np import torch @@ -366,9 +366,15 @@ def test_step(self, batch, batch_idx): ) return loss - def predict_step(self, data): + def predict_step(self, data) -> List[torch.Tensor]: hamiltonian = self(data) - return hamiltonian + sizes = self._get_hamiltonian_sizes(data) + hamiltonians_list = [] + for idx in range(1, len(sizes)): + H = hamiltonian[sizes[idx-1]:sizes[idx], + sizes[idx-1]:sizes[idx]] + hamiltonians_list.append(H) + return hamiltonians_list def configure_optimizers(self): optimizer = self.hparams.optimizer(params=self.parameters()) @@ -452,10 +458,10 @@ def _get_batch_size(self, batch): return bsz def _get_hamiltonian_sizes(self, batch): - sizes = [] + sizes = [0] for idx in range(batch.ptr.shape[0] - 1): atoms = batch.z[batch.ptr[idx]: batch.ptr[idx + 1]] size = sum([self.net.orbital_mask[atom] for atom in atoms]) - sizes.append(size) + sizes.append(size + sum(sizes)) return sizes \ No newline at end of file From 109daa74253df11c829f654911ab072c40b6c451 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 26 Mar 2024 14:33:16 +0300 Subject: [PATCH 34/57] edit get method to remove warnings --- nablaDFT/dataset/pyg_datasets.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index d7e2131..38859ba 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -158,20 +158,20 @@ def len(self) -> int: def get(self, idx): data = self.db[idx] - z = torch.tensor(data[0]).long() - positions = torch.tensor(data[1]).to(self.dtype) + z = torch.tensor(data[0].copy()).long() + positions = torch.tensor(data[1].copy()).to(self.dtype) # see notes - hamiltonian = data[4] + hamiltonian = data[4].copy() if self.include_overlap: - overlap = data[5] + overlap = data[5].copy() else: overlap = None if self.include_core: - core = data[6] + core = data[6].copy() else: core = None - y = torch.from_numpy(data[2]).to(self.dtype) - forces = torch.from_numpy(data[3]).to(self.dtype) + y = torch.from_numpy(data[2].copy()).to(self.dtype) + forces = torch.from_numpy(data[3].copy()).to(self.dtype) data = Data( z=z, pos=positions, y=y, forces=forces, From 5b3b8190650f975cb867c91642a52abc1f7f7c82 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 26 Mar 2024 16:26:55 +0300 Subject: [PATCH 35/57] fixed hamiltonian dataset cfg, ema models, qhnet inference hamiltonian sizes --- config/datamodule/nablaDFT_hamiltonian_test.yaml | 4 ++-- nablaDFT/dataset/pyg_datasets.py | 6 +++--- nablaDFT/equiformer_v2/equiformer_v2_oc20.py | 6 +++++- nablaDFT/escn/escn.py | 4 ++++ nablaDFT/gemnet_oc/gemnet_oc.py | 4 ++++ nablaDFT/qhnet/qhnet.py | 9 +++++++-- 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/config/datamodule/nablaDFT_hamiltonian_test.yaml b/config/datamodule/nablaDFT_hamiltonian_test.yaml index 71ea2ea..47b4a66 100644 --- a/config/datamodule/nablaDFT_hamiltonian_test.yaml +++ b/config/datamodule/nablaDFT_hamiltonian_test.yaml @@ -1,7 +1,7 @@ # Dataset config for torch geometric nablaDFT -_target_: nablaDFT.dataset.PyGHamiltonianNablaDFT +_target_: nablaDFT.dataset.PyGHamiltonianDataModule -root: ./datasets/nablaDFT/test +root: ./datasets/nablaDFT/hamiltonian dataset_name: ${dataset_name} seed: 23 # Dataloader args diff --git 
a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index 38859ba..610a4ed 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -128,9 +128,9 @@ def processed_file_names(self) -> str: def __init__( self, - datapath="database", - dataset_name="dataset_train_2k", - split: str="train", + datapath: str = "database", + dataset_name: str = "dataset_train_2k", + split: str = "train", include_hamiltonian: bool = True, include_overlap: bool = False, include_core: bool = False, diff --git a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py index 070885a..d9fe656 100644 --- a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py +++ b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py @@ -799,7 +799,11 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") - + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 for name, loss in self.hparams.losses.items(): diff --git a/nablaDFT/escn/escn.py b/nablaDFT/escn/escn.py index cded69b..8a748f2 100644 --- a/nablaDFT/escn/escn.py +++ b/nablaDFT/escn/escn.py @@ -1175,6 +1175,10 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 diff --git a/nablaDFT/gemnet_oc/gemnet_oc.py b/nablaDFT/gemnet_oc/gemnet_oc.py index e8bf524..f3e7387 100644 --- a/nablaDFT/gemnet_oc/gemnet_oc.py +++ b/nablaDFT/gemnet_oc/gemnet_oc.py @@ -1534,6 +1534,10 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + def _calculate_loss(self, y_pred, y_true) -> float: total_loss = 0.0 for name, loss in self.hparams.losses.items(): diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index a0d0e05..c07749b 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -293,7 +293,8 @@ def __init__( self.save_hyperparameters(logger=True, ignore=['net']) def forward(self, data: Data): - hamiltonian = self.net(data) + with self.ema.average_parameters(): + hamiltonian = self.net(data) return hamiltonian def step(self, batch, calculate_metrics: bool = False): @@ -402,6 +403,10 @@ def on_test_start(self) -> None: self._instantiate_ema() self._check_devices() + def on_predict_start(self) -> None: + self._instantiate_ema() + self._check_devices() + def on_validation_epoch_end(self) -> None: self._reduce_metrics(step_type="val") @@ -461,7 +466,7 @@ def _get_hamiltonian_sizes(self, batch): sizes = [0] for idx in range(batch.ptr.shape[0] - 1): atoms = batch.z[batch.ptr[idx]: batch.ptr[idx + 1]] - size = sum([self.net.orbital_mask[atom] for atom in atoms]) + size = sum([self.net.orbital_mask[atom.item()].shape[0] for atom in atoms]) sizes.append(size + sum(sizes)) return sizes \ No newline at end of file From 18a5d5b29fff053e863d2132dcc633e22043dada Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 26 Mar 2024 16:58:14 +0300 Subject: [PATCH 36/57] fix checkpoint save for qhnet --- nablaDFT/qhnet/qhnet.py | 6 +++++- 
1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index c07749b..c91e2fa 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -412,7 +412,11 @@ def on_validation_epoch_end(self) -> None: def on_test_epoch_end(self) -> None: self._reduce_metrics(step_type="test") - + + def on_save_checkpoint(self, checkpoint) -> None: + with self.ema.average_parameters(): + checkpoint['state_dict'] = self.state_dict() + def _calculate_loss(self, y_pred, y_true, masks) -> float: total_loss = 0.0 for name, loss in self.hparams.losses.items(): From 12fcdf7787fd04cc67c789fbd3e84400e2d035fb Mon Sep 17 00:00:00 2001 From: BerAnton Date: Wed, 27 Mar 2024 16:08:36 +0300 Subject: [PATCH 37/57] fix gemnet/equiformer/escn/qhnet ema during test and predict --- nablaDFT/equiformer_v2/equiformer_v2_oc20.py | 7 ++----- nablaDFT/escn/escn.py | 7 ++----- nablaDFT/gemnet_oc/gemnet_oc.py | 10 +++------- nablaDFT/qhnet/qhnet.py | 3 +-- 4 files changed, 8 insertions(+), 19 deletions(-) diff --git a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py index d9fe656..fa24169 100644 --- a/nablaDFT/equiformer_v2/equiformer_v2_oc20.py +++ b/nablaDFT/equiformer_v2/equiformer_v2_oc20.py @@ -750,8 +750,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -791,7 +790,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -837,12 +835,11 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/escn/escn.py b/nablaDFT/escn/escn.py index 8a748f2..529e196 100644 --- a/nablaDFT/escn/escn.py +++ b/nablaDFT/escn/escn.py @@ -1126,8 +1126,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -1167,7 +1166,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -1213,12 +1211,11 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/gemnet_oc/gemnet_oc.py b/nablaDFT/gemnet_oc/gemnet_oc.py index f3e7387..be99651 100644 --- a/nablaDFT/gemnet_oc/gemnet_oc.py +++ b/nablaDFT/gemnet_oc/gemnet_oc.py @@ -1431,7 +1431,6 @@ def forward(self, data: Data): def step(self, batch, calculate_metrics: 
bool = False): energy_out, forces_out = self.net(batch) - # TODO: temp workaround forces = batch.forces preds = {"energy": energy_out, "forces": forces_out} target = {"energy": batch.y, "forces": forces} @@ -1484,8 +1483,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -1525,7 +1523,6 @@ def on_fit_start(self) -> None: self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -1570,13 +1567,12 @@ def _reduce_metrics(self, step_type: str = "train"): self.hparams.metric.reset() def _check_devices(self): - self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) + self.hparams.metric = self.hparams.metric.to(self.device) def _instantiate_ema(self): if self.ema is not None: self.ema = self.ema(self.parameters()) + self.ema.to(self.device) def _get_batch_size(self, batch): """Function for batch size infer.""" diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index c91e2fa..3652f6a 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -293,8 +293,7 @@ def __init__( self.save_hyperparameters(logger=True, ignore=['net']) def forward(self, data: Data): - with self.ema.average_parameters(): - hamiltonian = self.net(data) + hamiltonian = self.net(data) return hamiltonian def step(self, batch, calculate_metrics: bool = False): From a3671361747adcbed65fa31a1402d0e8ec2a9243 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 28 Mar 2024 15:18:01 +0300 Subject: [PATCH 38/57] modified setup: up to py3.8, attempt to infer py version during setup --- setup.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index baaceb4..49dfa13 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,23 @@ +import sys import os import io from setuptools import setup, find_packages -CUDA = "cu121" - - def read(fname): with io.open(os.path.join(os.path.dirname(__file__), fname), encoding="utf-8") as f: return f.read() +def get_python_version(): + version_info = sys.version_info + major = version_info[0] + minor = version_info[1] + return f"cp{major}{minor}" + + +CUDA = "cu121" +PYTHON_VERSION = get_python_version() setup( @@ -20,7 +27,7 @@ def read(fname): url="https://github.com/AIRI-Institute/nablaDFT", packages=find_packages(), include_package_data=True, - python_requires=">=3.6", + python_requires=">=3.8", install_requires=[ "numpy>=1.26", "sympy==1.12", @@ -32,9 +39,9 @@ def read(fname): "pyyaml", "hydra-core==1.2.0", "torch==2.2.0", - "torch-scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", - "torch-sparse @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_sparse-0.6.18%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", - "torch-cluster @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_cluster-1.6.3%2Bpt22cu121-cp39-cp39-linux_x86_64.whl", + f"torch-scatter @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_scatter-2.1.2%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", + f"torch-sparse @ https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_sparse-0.6.18%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", + f"torch-cluster @ 
https://data.pyg.org/whl/torch-2.2.0%2Bcu121/torch_cluster-1.6.3%2Bpt22cu121-{PYTHON_VERSION}-{PYTHON_VERSION}-linux_x86_64.whl", "pytorch_lightning==2.1.4", "torch-geometric==2.4.0", "torchmetrics==1.0.1", From 10e97c25d0599dec5498963f12e458990f37fc62 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 28 Mar 2024 15:28:04 +0300 Subject: [PATCH 39/57] updated manifest to include equiformer and escn jd files --- MANIFEST.in | 2 ++ 1 file changed, 2 insertions(+) diff --git a/MANIFEST.in b/MANIFEST.in index 24c9839..7ab3152 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,4 @@ include nablaDFT/links/* include nablaDFT/data/* +include nablaDFT/equiformer_v2/Jd.pt +include nablaDFT/escn/Jd.pt \ No newline at end of file From 382f01790267268bfa96030f88a131ccee470bbf Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 28 Mar 2024 16:55:04 +0300 Subject: [PATCH 40/57] rework ase task due to ipykernel bug --- nablaDFT/ase_model/task.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/nablaDFT/ase_model/task.py b/nablaDFT/ase_model/task.py index 2e90685..96fb172 100644 --- a/nablaDFT/ase_model/task.py +++ b/nablaDFT/ase_model/task.py @@ -1,11 +1,9 @@ from typing import Any, Dict, List, Optional, Type -import pytorch_lightning as pl import schnetpack as spk import torch from schnetpack.model.base import AtomisticModel from schnetpack.task import UnsupervisedModelOutput -from torch import nn class AtomisticTaskFixed(spk.task.AtomisticTask): @@ -21,20 +19,18 @@ def __init__( scheduler_monitor: Optional[str] = None, warmup_steps: int = 0, ): - pl.LightningModule.__init__(self) + + super(AtomisticTaskFixed, self).__init__( + model=model, + outputs=outputs, + optimizer_cls=optimizer_cls, + optimizer_args=optimizer_args, + scheduler_cls=scheduler_cls, + scheduler_args=scheduler_args, + scheduler_monitor=scheduler_monitor, + warmup_steps=warmup_steps + ) self.model_name = model_name - self.model = model - self.optimizer_cls = optimizer_cls - self.optimizer_kwargs = optimizer_args - self.scheduler_cls = scheduler_cls - self.scheduler_kwargs = scheduler_args - self.schedule_monitor = scheduler_monitor - self.outputs = nn.ModuleList(outputs) - - self.grad_enabled = len(self.model.required_derivatives) > 0 - self.lr = optimizer_args["lr"] - self.warmup_steps = warmup_steps - self.save_hyperparameters(ignore=["model"]) def predict_step(self, batch, batch_idx): torch.set_grad_enabled(self.grad_enabled) From 700d5e6d23d595d9e52b268016cf8216badc99f0 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 2 Apr 2024 16:07:27 +0300 Subject: [PATCH 41/57] minor changes to painn --- config/model/painn-oc.yaml | 13 ++++++------- nablaDFT/painn_pyg/painn.py | 7 +++---- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/config/model/painn-oc.yaml b/config/model/painn-oc.yaml index 5c21880..1ec9b30 100644 --- a/config/model/painn-oc.yaml +++ b/config/model/painn-oc.yaml @@ -3,10 +3,10 @@ _target_: nablaDFT.painn_pyg.PaiNNLightning model_name: "PAINN-OC" net: _target_: nablaDFT.painn_pyg.PaiNN - hidden_channels: 512 + hidden_channels: 128 num_layers: 6 - num_rbf: 128 - cutoff: 12.0 + num_rbf: 100 + cutoff: 5.0 max_neighbors: 50 rbf: name: 'gaussian' @@ -23,9 +23,7 @@ net: optimizer: _target_: torch.optim.AdamW _partial_: true - amsgrad: true - betas: [0.9, 0.95] - lr: 1e-3 + lr: 1e-4 weight_decay: 0 lr_scheduler: @@ -33,6 +31,7 @@ lr_scheduler: _partial_: true factor: 0.8 patience: 10 + min_lr: 1e-6 losses: energy: @@ -41,7 +40,7 @@ losses: _target_: 
nablaDFT.gemnet_oc.loss.L2Loss loss_coefs: energy: 1.0 - forces: 100.0 + forces: 1.0 ema: _target_: torch_ema.ExponentialMovingAverage diff --git a/nablaDFT/painn_pyg/painn.py b/nablaDFT/painn_pyg/painn.py index 1ef8bd7..ce2d337 100644 --- a/nablaDFT/painn_pyg/painn.py +++ b/nablaDFT/painn_pyg/painn.py @@ -4,7 +4,7 @@ import torch from torch import nn -from torch_geometric.nn import MessagePassing +from torch_geometric.nn import MessagePassing, radius_graph from torch_scatter import scatter, segment_coo from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -24,7 +24,6 @@ compute_neighbors, get_edge_id, get_pbc_distances, - radius_graph, radius_graph_pbc, repeat_blocks, load_scales_compat, @@ -389,9 +388,9 @@ def forward(self, data): else: forces = -1 * ( torch.autograd.grad( - per_atom_energy, + energy, pos, - grad_outputs=torch.ones_like(per_atom_energy), + grad_outputs=torch.ones_like(energy), create_graph=self.training, )[0] ) From 2d32361857626c5559a2b2a12086f418bc885063 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 15 Apr 2024 12:11:34 +0300 Subject: [PATCH 42/57] fixed bug when no postprocessors are present --- nablaDFT/ase_model/task.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nablaDFT/ase_model/task.py b/nablaDFT/ase_model/task.py index 96fb172..2aecc23 100644 --- a/nablaDFT/ase_model/task.py +++ b/nablaDFT/ase_model/task.py @@ -51,10 +51,11 @@ def predict_step(self, batch, batch_idx): def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # reshape model.postprocessors (AddOffsets) # otherwise during test error will occur - checkpoint["state_dict"]["model.postprocessors.0.mean"] = checkpoint[ - "state_dict" - ]["model.postprocessors.0.mean"].reshape(1) - + if checkpoint["state_dict"].get("model.postprocessors.0.mean", None): + checkpoint["state_dict"]["model.postprocessors.0.mean"] = checkpoint[ + "state_dict" + ]["model.postprocessors.0.mean"].reshape(1) + # override base class method def predict_without_postprocessing(self, batch): pred = self(batch) From aba850cc42416a3e2346701f5c416a622b0a8c05 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 15 Apr 2024 12:13:50 +0300 Subject: [PATCH 43/57] fixed bug for hamiltonian dataset for DDP --- nablaDFT/dataset/pyg_datasets.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index 610a4ed..bc6a5c8 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -191,9 +191,6 @@ def download(self) -> None: with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading split: {self.dataset_name}") as t: request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) - def process(self) -> None: - pass - def _get_max_orbitals(self, datapath, dataset_name): db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) if not os.path.exists(db_path): From 3caa55fa5d8d903425b6ea9c44512069fc81d180 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 15 Apr 2024 15:23:15 +0300 Subject: [PATCH 44/57] fixed PyG PaiNN implementation --- config/model/painn-oc.yaml | 15 +-- nablaDFT/painn_pyg/painn.py | 179 ++++++++++++++---------------------- 2 files changed, 72 insertions(+), 122 deletions(-) diff --git a/config/model/painn-oc.yaml b/config/model/painn-oc.yaml index 1ec9b30..c13ff17 100644 --- a/config/model/painn-oc.yaml +++ b/config/model/painn-oc.yaml @@ -7,7 +7,7 @@ net: num_layers: 6 num_rbf: 100 cutoff: 
5.0 - max_neighbors: 50 + max_neighbors: 100 rbf: name: 'gaussian' envelope: @@ -17,20 +17,20 @@ net: direct_forces: false use_pbc: false otf_graph: true - num_elements: 65 + num_elements: 100 optimizer: _target_: torch.optim.AdamW _partial_: true - lr: 1e-4 + lr: 5e-4 weight_decay: 0 lr_scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true factor: 0.8 - patience: 10 + patience: 100 min_lr: 1e-6 losses: @@ -42,11 +42,6 @@ loss_coefs: energy: 1.0 forces: 1.0 -ema: - _target_: torch_ema.ExponentialMovingAverage - _partial_: true - decay: 0.9999 - metric: _target_: torchmetrics.MultitaskWrapper _convert_: all @@ -54,4 +49,4 @@ metric: energy: _target_: torchmetrics.MeanAbsoluteError forces: - _target_: torchmetrics.MeanAbsoluteError + _target_: torchmetrics.MeanAbsoluteError \ No newline at end of file diff --git a/nablaDFT/painn_pyg/painn.py b/nablaDFT/painn_pyg/painn.py index ce2d337..aad0327 100644 --- a/nablaDFT/painn_pyg/painn.py +++ b/nablaDFT/painn_pyg/painn.py @@ -17,16 +17,15 @@ RadialBasis, ScaledSiLU, ScaleFactor, + CosineCutoff ) -# from torch_geometric.nn import radius_graph from .utils import ( compute_neighbors, get_edge_id, get_pbc_distances, radius_graph_pbc, repeat_blocks, - load_scales_compat, ) @@ -53,7 +52,6 @@ def __init__( use_pbc: bool = True, otf_graph: bool = True, num_elements: int = 83, - scale_file: Optional[str] = None, ) -> None: super(PaiNN, self).__init__() @@ -89,22 +87,75 @@ def __init__( PaiNNMessage(hidden_channels, num_rbf).jittable() ) self.update_layers.append(PaiNNUpdate(hidden_channels)) - setattr(self, "upd_out_scalar_scale_%d" % i, ScaleFactor()) self.out_energy = nn.Sequential( nn.Linear(hidden_channels, hidden_channels // 2), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels // 2, 1), ) if self.regress_forces is True and self.direct_forces is True: self.out_forces = PaiNNOutput(hidden_channels) + self.reset_parameters() - self.inv_sqrt_2 = 1 / math.sqrt(2.0) - self.reset_parameters() + @torch.enable_grad() + def forward(self, data): + pos = data.pos + batch = data.batch + z = data.z.long() - load_scales_compat(self, scale_file) + if self.regress_forces and not self.direct_forces: + pos = pos.requires_grad_(True) + + ( + edge_index, + neighbors, + edge_dist, + edge_vector, + id_swap, + ) = self.generate_graph_values(data) + + assert z.dim() == 1 and z.dtype == torch.long + + edge_rbf = self.radial_basis(edge_dist) # rbf * envelope + + x = self.atom_emb(z) + vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) + + #### Interaction blocks ############################################### + + for i in range(self.num_layers): + dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) + + x = x + dx + vec = vec + dvec + + dx, dvec = self.update_layers[i](x, vec) + + x = x + dx + vec = vec + dvec + + #### Output block ##################################################### + per_atom_energy = self.out_energy(x).squeeze(1) + energy = scatter(per_atom_energy, batch, dim=0) + + if self.regress_forces: + if self.direct_forces: + forces = self.out_forces(x, vec) + return energy, forces + else: + forces = -1 * ( + torch.autograd.grad( + energy, + pos, + grad_outputs=torch.ones_like(energy), + create_graph=self.training, + )[0] + ) + return energy, forces + else: + return energy def reset_parameters(self) -> None: nn.init.xavier_uniform_(self.out_energy[0].weight) @@ -307,11 +358,6 @@ def generate_graph_values(self, data): empty_image = neighbors == 0 if torch.any(empty_image): print(f"An image has no 
neighbors! #images = {empty_image.sum().item()}") - # raise ValueError( - # f"An image has no neighbors: id={data.id[empty_image]}, " - # f"sid={data.sid[empty_image]}, fid={data.fid[empty_image]}" - # ) - # Symmetrize edges for swapping in symmetric message passing ( edge_index, @@ -337,67 +383,6 @@ def generate_graph_values(self, data): id_swap, ) - - @torch.enable_grad() - def forward(self, data): - pos = data.pos - batch = data.batch - z = data.z.long() - - if self.regress_forces and not self.direct_forces: - pos = pos.requires_grad_(True) - - ( - edge_index, - neighbors, - edge_dist, - edge_vector, - id_swap, - ) = self.generate_graph_values(data) - - assert z.dim() == 1 and z.dtype == torch.long - - edge_rbf = self.radial_basis(edge_dist) # rbf * envelope - - x = self.atom_emb(z) - vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device) - - #### Interaction blocks ############################################### - - for i in range(self.num_layers): - dx, dvec = self.message_layers[i](x, vec, edge_index, edge_rbf, edge_vector) - - x = x + dx - vec = vec + dvec - x = x * self.inv_sqrt_2 - - dx, dvec = self.update_layers[i](x, vec) - - x = x + dx - vec = vec + dvec - x = getattr(self, "upd_out_scalar_scale_%d" % i)(x) - - #### Output block ##################################################### - per_atom_energy = self.out_energy(x).squeeze(1) - energy = scatter(per_atom_energy, batch, dim=0) - - if self.regress_forces: - if self.direct_forces: - forces = self.out_forces(x, vec) - return energy, forces - else: - forces = -1 * ( - torch.autograd.grad( - energy, - pos, - grad_outputs=torch.ones_like(energy), - create_graph=self.training, - )[0] - ) - return energy, forces - else: - return energy - def _generate_graph( self, data, @@ -470,7 +455,6 @@ def _generate_graph( j, i = edge_index distance_vec = data.pos[j] - data.pos[i] edge_dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt() - # edge_dist = distance_vec.norm(dim=-1) cell_offsets = torch.zeros(edge_index.shape[1], 3, device=data.pos.device) cell_offset_distances = torch.zeros_like( cell_offsets, device=data.pos.device @@ -513,15 +497,10 @@ def __init__( self.x_proj = nn.Sequential( nn.Linear(hidden_channels, hidden_channels), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3), ) self.rbf_proj = nn.Linear(num_rbf, hidden_channels * 3) - - self.inv_sqrt_3 = 1 / math.sqrt(3.0) - self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.x_layernorm = nn.LayerNorm(hidden_channels) - self.reset_parameters() def reset_parameters(self) -> None: @@ -531,10 +510,9 @@ def reset_parameters(self) -> None: self.x_proj[2].bias.data.fill_(0) nn.init.xavier_uniform_(self.rbf_proj.weight) self.rbf_proj.bias.data.fill_(0) - self.x_layernorm.reset_parameters() def forward(self, x, vec, edge_index, edge_rbf, edge_vector): - xh = self.x_proj(self.x_layernorm(x)) + xh = self.x_proj(x) # TODO(@abhshkdz): Nans out with AMP here during backprop. Debug / fix. 
rbfh = self.rbf_proj(edge_rbf) @@ -553,10 +531,7 @@ def forward(self, x, vec, edge_index, edge_rbf, edge_vector): def message(self, xh_j, vec_j, rbfh_ij, r_ij): x, xh2, xh3 = torch.split(xh_j * rbfh_ij, self.hidden_channels, dim=-1) - xh2 = xh2 * self.inv_sqrt_3 - vec = vec_j * xh2.unsqueeze(1) + xh3.unsqueeze(1) * r_ij.unsqueeze(2) - vec = vec * self.inv_sqrt_h return x, vec @@ -586,13 +561,9 @@ def __init__(self, hidden_channels) -> None: self.vec_proj = nn.Linear(hidden_channels, hidden_channels * 2, bias=False) self.xvec_proj = nn.Sequential( nn.Linear(hidden_channels * 2, hidden_channels), - ScaledSiLU(), + nn.SiLU(), nn.Linear(hidden_channels, hidden_channels * 3), ) - - self.inv_sqrt_2 = 1 / math.sqrt(2.0) - self.inv_sqrt_h = 1 / math.sqrt(hidden_channels) - self.reset_parameters() def reset_parameters(self) -> None: @@ -604,7 +575,7 @@ def reset_parameters(self) -> None: def forward(self, x, vec): vec1, vec2 = torch.split(self.vec_proj(vec), self.hidden_channels, dim=-1) - vec_dot = (vec1 * vec2).sum(dim=1) * self.inv_sqrt_h + vec_dot = (vec1 * vec2).sum(dim=1) # NOTE: Can't use torch.norm because the gradient is NaN for input = 0. # Add an epsilon offset to make sure sqrt is always positive. @@ -614,7 +585,6 @@ def forward(self, x, vec): xvec1, xvec2, xvec3 = torch.split(x_vec_h, self.hidden_channels, dim=-1) dx = xvec1 + xvec2 * vec_dot - dx = dx * self.inv_sqrt_2 dvec = xvec3.unsqueeze(1) * vec1 @@ -701,14 +671,12 @@ def __init__( optimizer: Optimizer, lr_scheduler: LRScheduler, losses: Dict, - ema, metric, loss_coefs ) -> None: super(PaiNNLightning, self).__init__() - self.ema = ema self.net = net - self.save_hyperparameters(logger=True, ignore=["net", "ema"]) + self.save_hyperparameters(logger=True, ignore=["net"]) def forward(self, data): energy, forces = self.net(data) @@ -747,8 +715,7 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "val/loss", loss, @@ -773,8 +740,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -806,15 +772,10 @@ def configure_optimizers(self): } return {"optimizer": optimizer} - def on_before_zero_grad(self, optimizer: Optimizer) -> None: - self.ema.update() - def on_fit_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_test_start(self) -> None: - self._instantiate_ema() self._check_devices() def on_validation_epoch_end(self) -> None: @@ -854,13 +815,7 @@ def _reduce_metrics(self, step_type: str = "train"): def _check_devices(self): self.hparams.metric = self.hparams.metric.to(self.device) - if self.ema is not None: - self.ema.to(self.device) - - def _instantiate_ema(self): - if self.ema is not None: - self.ema = self.ema(self.parameters()) - + def _get_batch_size(self, batch): """Function for batch size infer.""" bsz = batch.batch.max().detach().item() + 1 # get batch size From dedc4e52cee7fda3770d57510aacdb742161fa14 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 15 Apr 2024 15:28:33 +0300 Subject: [PATCH 45/57] remove import --- nablaDFT/painn_pyg/painn.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
a/nablaDFT/painn_pyg/painn.py b/nablaDFT/painn_pyg/painn.py index aad0327..b96fa01 100644 --- a/nablaDFT/painn_pyg/painn.py +++ b/nablaDFT/painn_pyg/painn.py @@ -16,8 +16,6 @@ AtomEmbedding, RadialBasis, ScaledSiLU, - ScaleFactor, - CosineCutoff ) from .utils import ( From 928cc86ad5ffceacd1da4940378e7371e29f0bd4 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 16 Apr 2024 18:56:54 +0300 Subject: [PATCH 46/57] rework graphormer preds functions --- nablaDFT/graphormer/graphormer_3d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nablaDFT/graphormer/graphormer_3d.py b/nablaDFT/graphormer/graphormer_3d.py index 8a7fb2d..2ec20c4 100644 --- a/nablaDFT/graphormer/graphormer_3d.py +++ b/nablaDFT/graphormer/graphormer_3d.py @@ -450,7 +450,9 @@ def test_step(self, batch, batch_idx): return loss def predict_step(self, data, **kwargs): - energy_out, forces_out, _ = self(data) + """Note: predictions output consistent with PyG networks""" + energy_out, forces_out, mask = self(data) + forces_out = torch.masked_select(forces_out, mask).reshape(-1, 3) return energy_out, forces_out def configure_optimizers(self): From 3c80540093e1e2542c0be8edfb8320817c0c1f6d Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 23 Apr 2024 12:54:43 +0300 Subject: [PATCH 47/57] updated docstrings, change HamiltonianDatabase interface --- nablaDFT/dataset/__init__.py | 16 ++- nablaDFT/dataset/atoms_datamodule.py | 63 +++------- nablaDFT/dataset/hamiltonian_dataset.py | 35 +++++- nablaDFT/dataset/nablaDFT_dataset.py | 140 +++++++++++++-------- nablaDFT/dataset/pyg_datasets.py | 68 +++++++--- nablaDFT/dataset/split.py | 2 + nablaDFT/optimization/calculator.py | 59 ++++----- nablaDFT/optimization/opt_utils.py | 5 + nablaDFT/optimization/optimizers.py | 3 +- nablaDFT/optimization/pyg_ase_interface.py | 56 ++++----- nablaDFT/optimization/task.py | 20 +-- nablaDFT/pipelines.py | 14 +++ nablaDFT/utils.py | 18 ++- 13 files changed, 293 insertions(+), 206 deletions(-) diff --git a/nablaDFT/dataset/__init__.py b/nablaDFT/dataset/__init__.py index f055433..28fdb0a 100644 --- a/nablaDFT/dataset/__init__.py +++ b/nablaDFT/dataset/__init__.py @@ -1,3 +1,13 @@ -from .nablaDFT_dataset import ASENablaDFT, PyGNablaDFTDataModule, PyGHamiltonianDataModule -from .hamiltonian_dataset import HamiltonianDataset, HamiltonianDatabase # database interface for Hamiltonian datasets -from .pyg_datasets import PyGNablaDFT, PyGHamiltonianNablaDFT # PyTorch Geometric interfaces for datasets \ No newline at end of file +from .nablaDFT_dataset import ( + ASENablaDFT, + PyGNablaDFTDataModule, + PyGHamiltonianDataModule, +) # PyTorch Lightning interfaces for datasets +from .hamiltonian_dataset import ( + HamiltonianDataset, + HamiltonianDatabase, +) # database interface for Hamiltonian datasets +from .pyg_datasets import ( + PyGNablaDFT, + PyGHamiltonianNablaDFT, +) # PyTorch Geometric interfaces for datasets diff --git a/nablaDFT/dataset/atoms_datamodule.py b/nablaDFT/dataset/atoms_datamodule.py index b3130a5..88c58ee 100644 --- a/nablaDFT/dataset/atoms_datamodule.py +++ b/nablaDFT/dataset/atoms_datamodule.py @@ -1,8 +1,6 @@ -# Overrided AtomsDataModule from SchNetPack -from copy import copy +"""Overrided AtomsDataModule from SchNetPack""" from typing import List, Dict, Optional, Union -import pytorch_lightning as pl import torch from schnetpack.data import ( @@ -10,14 +8,17 @@ load_dataset, AtomsLoader, SplittingStrategy, - AtomsDataModule + AtomsDataModule, ) class AtomsDataModule(AtomsDataModule): """ - Overrided AtomsDataModule from 
SchNetPack with additional - methods for prediction task. + Overrided AtomsDataModule from SchNetPack with predict_dataloader + method and overrided setup for prediction task. + + Args: + split (str): string contains type of task/dataset, must be one of ['train', 'test', 'predict'] """ def __init__( @@ -47,44 +48,6 @@ def __init__( splitting: Optional[SplittingStrategy] = None, pin_memory: Optional[bool] = False, ): - """ - Args: - split: string contains type of task/dataset, must be one of ['train', 'test', 'predict'] - datapath: path to dataset - batch_size: (train) batch size - num_train: number of training examples (absolute or relative) - num_val: number of validation examples (absolute or relative) - num_test: number of test examples (absolute or relative) - split_file: path to npz file with data partitions - format: dataset format - load_properties: subset of properties to load - val_batch_size: validation batch size. If None, use test_batch_size, then - batch_size. - test_batch_size: test batch size. If None, use val_batch_size, then - batch_size. - transforms: Preprocessing transform applied to each system separately before - batching. - train_transforms: Overrides transform_fn for training. - val_transforms: Overrides transform_fn for validation. - test_transforms: Overrides transform_fn for testing. - num_workers: Number of data loader workers. - num_val_workers: Number of validation data loader workers - (overrides num_workers). - num_test_workers: Number of test data loader workers - (overrides num_workers). - property_units: Dictionary from property to corresponding unit as a string - (eV, kcal/mol, ...). - distance_unit: Unit of the atom positions and cell as a string - (Ang, Bohr, ...). - data_workdir: Copy data here as part of setup, e.g. to a local file - system for faster performance. - cleanup_workdir_stage: Determines after which stage to remove the data - workdir - splitting: Method to generate train/validation/test partitions - (default: RandomSplit) - pin_memory: If true, pin memory of loaded data to GPU. Default: Will be - set to true, when GPUs are used. 
- """ super().__init__( datapath=datapath, batch_size=batch_size, @@ -108,12 +71,17 @@ def __init__( data_workdir=data_workdir, cleanup_workdir_stage=cleanup_workdir_stage, splitting=splitting, - pin_memory=pin_memory + pin_memory=pin_memory, ) self.split = split self._predict_dataloader = None - + def setup(self, stage: Optional[str] = None): + """Overrided method from original AtomsDataModule class + + Args: + stage (str): trainer stage, must be one of ['fit', 'test', 'predict'] + """ # check whether data needs to be copied if self.data_workdir is None: datapath = self.datapath @@ -143,12 +111,13 @@ def setup(self, stage: Optional[str] = None): self._setup_transforms() def predict_dataloader(self) -> AtomsLoader: + """Describes predict dataloader, used for prediction task""" if self._predict_dataloader is None: self._predict_dataloader = AtomsLoader( self.test_dataset, batch_size=self.test_batch_size, num_workers=self.num_test_workers, pin_memory=self._pin_memory, - shuffle=False + shuffle=False, ) return self._predict_dataloader diff --git a/nablaDFT/dataset/hamiltonian_dataset.py b/nablaDFT/dataset/hamiltonian_dataset.py index 9ca89d7..1361ad7 100644 --- a/nablaDFT/dataset/hamiltonian_dataset.py +++ b/nablaDFT/dataset/hamiltonian_dataset.py @@ -22,7 +22,11 @@ class HamiltonianDatabase: C (Norb, Norb) core hamiltonian in atomic units moses_id () (int) molecule id in MOSES dataset conformer_id () (int) conformation id + + Args: + filename (str): path to database. """ + def __init__(self, filename, flags=apsw.SQLITE_OPEN_READONLY): self.db = filename self.connections = {} # allow multiple connections (needed for multi-threading) @@ -38,12 +42,24 @@ def __getitem__(self, idx): data = cursor.execute( """SELECT * FROM data WHERE id IN (""" + str(idx)[1:-1] + ")" ).fetchall() - return [self._unpack_data_tuple(i) for i in data] + ids = cursor.execute( + """SELECT * FROM dataset_ids WHERE id IN (""" + str(idx)[1:-1] + ")" + ).fetchall() + moses_ids, conformer_ids = [i[1] for i in ids], [i[2] for i in ids] + unpacked_data = [ + (*self._unpack_data_tuple(chunk), moses_ids[i], conformer_ids[i]) + for i, chunk in enumerate(data) + ] + return unpacked_data else: data = cursor.execute( """SELECT * FROM data WHERE id=""" + str(idx) ).fetchone() - return self._unpack_data_tuple(data) + ids = cursor.execute( + """SELECT * FROM dataset_ids WHERE id=""" + str(idx) + ).fetchall()[0] + moses_id, conformer_id = ids[1], ids[2] + return (*self._unpack_data_tuple(data), moses_id, conformer_id) def _unpack_data_tuple(self, data): N = ( @@ -244,8 +260,21 @@ def Z(self): return self._deblob(data[2], dtype=np.int32, shape=(N,)) - class HamiltonianDataset(torch.utils.data.Dataset): + """PyTorch interface for nablaDFT Hamiltonian databases. + + Collates hamiltonian, overlap and core hamiltonian matrices + from batch into block diagonal matrix. + + Args: + - filepath (str): path to database. + - max_batch_orbitals (int): maximum number of orbitals in one batch. + - max_batch_atoms (int): maximum number of atoms in one batch.. + - max_squares (int): maximum size of block diagonal matrix in one batch. + - subset() path to saved numpy array with selected indexes. + - dtype (torch.dtype): defines torch.dtype for tensors. 
+ """ + def __init__( self, filepath, diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 35ffdd8..c771470 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -1,4 +1,5 @@ """Module defines Pytorch Lightning DataModule interfaces for various NablaDFT datasets""" + import json import os from typing import Optional @@ -18,6 +19,20 @@ class ASENablaDFT(AtomsDataModule): + """PytorchLightning interface for nablaDFT ASE datasets. + + Args: + split (str): type of split, must be one of ['train', 'test', 'predict']. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + datapath (str): path to existing dataset directory or location for download. + train_ratio (float): dataset part used for training. + val_ratio (float): dataset part used for validation. + test_ratio (float): dataset part used for test or prediction. + train_transforms (Callable): data transform, called for every sample in training dataset. + val_transforms (Callable): data transform, called for every sample in validation dataset. + test_transforms (Callable): data transform, called for every sample in test dataset. + """ + def __init__( self, split: str, @@ -34,6 +49,7 @@ def __init__( format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, **kwargs, ): + """""" super().__init__( split=split, datapath=datapath, @@ -92,26 +108,27 @@ def prepare_data(self): class PyGDataModule(LightningDataModule): - """Parent class which encapsulates PyG dataset for use with Pytorch Lightning Trainer. + """Parent class which encapsulates PyG dataset to use with Pytorch Lightning Trainer. In order to add new dataset variant, define children class with setup() method. Args: - - root (str): path to directory with r'raw/' subfolder with existing dataset or download location. - - dataset_name (str): split name from links .json file. - - train_size (float): part of dataset used for training, must be in [0, 1]. - - val_size (float): part of dataset used for validation, must be in [0, 1]. - - seed (int): seed number, used for torch.Generator object during train/val split. - - kwargs (Dict): other arguments for dataset. + root (str): path to directory with r'raw/' subfolder with existing dataset or download location. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + train_size (float): part of dataset used for training, must be in [0, 1]. + val_size (float): part of dataset used for validation, must be in [0, 1]. + seed (int): seed number, used for torch.Generator object during train/val split. + **kwargs: arguments for torch.DataLoader. 
""" + def __init__( - self, - root: str, - dataset_name: str, - train_size: float = 0.9, - val_size: float = 0.1, - seed: int = 23, - **kwargs - ) -> None: + self, + root: str, + dataset_name: str, + train_size: float = 0.9, + val_size: float = 0.1, + seed: int = 23, + **kwargs, + ) -> None: super().__init__() self.dataset_train = None self.dataset_val = None @@ -122,16 +139,6 @@ def __init__( self.dataset_name = dataset_name self.seed = seed self.sizes = [train_size, val_size] - dataloader_keys = [ - "batch_size", "num_workers", - "pin_memory", "persistent_workers" - ] - self.dataloader_kwargs = {} - for key in dataloader_keys: - val = kwargs.get(key, None) - self.dataloader_kwargs[key] = val - if val is not None: - del kwargs[key] self.kwargs = kwargs def dataloader(self, dataset, **kwargs): @@ -139,58 +146,83 @@ def dataloader(self, dataset, **kwargs): def setup(self, stage: str) -> None: raise NotImplementedError - + def train_dataloader(self): - return self.dataloader(self.dataset_train, shuffle=True, **self.dataloader_kwargs) - + return self.dataloader(self.dataset_train, shuffle=True, **self.kwargs) + def val_dataloader(self): - return self.dataloader(self.dataset_val, shuffle=False, **self.dataloader_kwargs) - + return self.dataloader(self.dataset_val, shuffle=False, **self.kwargs) + def test_dataloader(self): - return self.dataloader(self.dataset_test, shuffle=False, **self.dataloader_kwargs) - + return self.dataloader(self.dataset_test, shuffle=False, **self.kwargs) + def predict_dataloader(self): - return self.dataloader(self.dataset_predict, shuffle=False, **self.dataloader_kwargs) + return self.dataloader(self.dataset_predict, shuffle=False, **self.kwargs) class PyGHamiltonianDataModule(PyGDataModule): - """DataModule for Hamiltonian NablaDFT dataset - + """DataModule for Hamiltonian nablaDFT dataset, subclass of PyGDataModule. + Keyword arguments: - - hamiltonian (bool): retrieve from database molecule's full hamiltonian matrix. True by default. - - include_overlap (bool): retrieve from database molecule's overlab matrix. - - include_core (bool): retrieve from databaes molecule's core hamiltonian matrix. + hamiltonian (bool): retrieve from database molecule's full hamiltonian matrix. True by default. + include_overlap (bool): retrieve from database molecule's overlab matrix. + include_core (bool): retrieve from databaes molecule's core hamiltonian matrix. + **kwargs: arguments for torch.DataLoader and PyGDataModule instance. See PyGDatamodule docs. 
""" + def __init__( - self, - root: str, - dataset_name: str, - train_size: float = None, - val_size: float = None, - **kwargs) -> None: + self, + root: str, + dataset_name: str, + train_size: float = None, + val_size: float = None, + **kwargs, + ) -> None: super().__init__(root, dataset_name, train_size, val_size, **kwargs) def setup(self, stage: str) -> None: if stage == "fit": - dataset = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) - self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, - generator=torch.Generator().manual_seed(self.seed)) + dataset = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "train", **self.kwargs + ) + self.dataset_train, self.dataset_val = random_split( + dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed) + ) elif stage == "test": - self.dataset_test = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "test", **self.kwargs) + self.dataset_test = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "test", **self.kwargs + ) elif stage == "predict": - self.dataset_predict = PyGHamiltonianNablaDFT(self.root, self.dataset_name, "predict", **self.kwargs) + self.dataset_predict = PyGHamiltonianNablaDFT( + self.root, self.dataset_name, "predict", **self.kwargs + ) class PyGNablaDFTDataModule(PyGDataModule): - def __init__(self, root: str, dataset_name: str, train_size: float = None, val_size: float = None, **kwargs) -> None: + """DataModule for nablaDFT dataset, subclass of PyGDataModule. + See PyGDatamodule doc.""" + + def __init__( + self, + root: str, + dataset_name: str, + train_size: float = None, + val_size: float = None, + **kwargs, + ) -> None: super().__init__(root, dataset_name, train_size, val_size, **kwargs) def setup(self, stage: str) -> None: if stage == "fit": dataset = PyGNablaDFT(self.root, self.dataset_name, "train", **self.kwargs) - self.dataset_train, self.dataset_val = random_split(dataset, self.sizes, - generator=torch.Generator().manual_seed(self.seed)) + self.dataset_train, self.dataset_val = random_split( + dataset, self.sizes, generator=torch.Generator().manual_seed(self.seed) + ) elif stage == "test": - self.dataset_test = PyGNablaDFT(self.root, self.dataset_name, "test", **self.kwargs) + self.dataset_test = PyGNablaDFT( + self.root, self.dataset_name, "test", **self.kwargs + ) elif stage == "predict": - self.dataset_predict = PyGNablaDFT(self.root, self.dataset_name, "predict", **self.kwargs) + self.dataset_predict = PyGNablaDFT( + self.root, self.dataset_name, "predict", **self.kwargs + ) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index bc6a5c8..cc24998 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -1,4 +1,5 @@ """Module describes PyTorch Geometric interfaces for various NablaDFT datasets""" + import json import os import logging @@ -19,8 +20,15 @@ class PyGNablaDFT(InMemoryDataset): - """Dataset adapter for ASE2PyG conversion. + """Pytorch Geometric interface for nablaDFT datasets. Based on https://github.com/atomicarchitects/equiformer/blob/master/datasets/pyg/md17.py + + Args: + datapath (str): path to existing dataset directory or location for download. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + split (str): type of split, must be one of ['train', 'test', 'predict']. + transform (Callable): callable data transform, called on every access to element. 
+ pre_transform (Callable): callable data transform, called on every access to element. """ db_suffix = ".db" @@ -74,8 +82,17 @@ def download(self) -> None: data = json.load(f) url = data[f"{self.split}_databases"][self.dataset_name] file_size = get_file_size(url) - with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading split: {self.dataset_name}") as t: - request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) + with tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + total=file_size, + desc=f"Downloading split: {self.dataset_name}", + ) as t: + request.urlretrieve( + url, self.raw_paths[0], reporthook=tqdm_download_hook(t) + ) def process(self) -> None: db = connect(self.raw_paths[0]) @@ -99,23 +116,25 @@ def process(self) -> None: class PyGHamiltonianNablaDFT(Dataset): - """Pytorch Geometric dataset for NablaDFT Hamiltonian database. + """Pytorch Geometric interface for nablaDFT Hamiltonian datasets. Args: - - datapath (str): path to existing dataset directory or location for download. - - dataset_name (str): split name from links .json. - - split (str): type of split, must be one of ['train', 'test', 'predict']. - - include_hamiltonian (bool): if True, retrieves full Hamiltonian matrices from database. - - include_overlap (bool): if True, retrieves overlap matrices from database. - - include_core (bool): if True, retrieves core Hamiltonian matrices from database. - - dtype (torch.dtype): defines torch.dtype for energy, positions, forces tensors. - - transform (Callable): callable data transform, called on every access to element. - - pre_transform (Callable): callable data transform, called on every access to element. + datapath (str): path to existing dataset directory or location for download. + dataset_name (str): split name from links .json or filename of existing file from datapath directory. + split (str): type of split, must be one of ['train', 'test', 'predict']. + include_hamiltonian (bool): if True, retrieves full Hamiltonian matrices from database. + include_overlap (bool): if True, retrieves overlap matrices from database. + include_core (bool): if True, retrieves core Hamiltonian matrices from database. + dtype (torch.dtype): defines torch.dtype for energy, positions and forces tensors. + transform (Callable): callable data transform, called on every access to element. + pre_transform (Callable): callable data transform, called on every access to element. + Note: Hamiltonian matrix for each molecule has different shape. PyTorch Geometric tries to concatenate each torch.Tensor in batch, so in order to make batch from data we leave all hamiltonian matrices in numpy array form. During train, these matrices will be yield as List[np.array]. 
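    Example (a minimal sketch illustrating the note above; the dataset location and
    split name are placeholders):

        >>> ds = PyGHamiltonianNablaDFT("./datasets/nablaDFT/train",
        ...                             "dataset_train_2k", "train",
        ...                             include_overlap=True)
        >>> sample = ds[0]
        >>> sample.z.shape, sample.forces.shape  # torch tensors of shape (N,) and (N, 3)
        >>> type(sample.hamiltonian)             # numpy.ndarray of shape (Norb, Norb)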
""" + db_suffix = ".db" @property @@ -173,8 +192,10 @@ def get(self, idx): y = torch.from_numpy(data[2].copy()).to(self.dtype) forces = torch.from_numpy(data[3].copy()).to(self.dtype) data = Data( - z=z, pos=positions, - y=y, forces=forces, + z=z, + pos=positions, + y=y, + forces=forces, hamiltonian=hamiltonian, overlap=overlap, core=core, @@ -185,11 +206,20 @@ def get(self, idx): def download(self) -> None: with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: - data = json.load(f) - url = data["train_databases"][self.dataset_name] + data = json.load(f) + url = data["train_databases"][self.dataset_name] file_size = get_file_size(url) - with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, total=file_size, desc=f"Downloading split: {self.dataset_name}") as t: - request.urlretrieve(url, self.raw_paths[0], reporthook=tqdm_download_hook(t)) + with tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + total=file_size, + desc=f"Downloading split: {self.dataset_name}", + ) as t: + request.urlretrieve( + url, self.raw_paths[0], reporthook=tqdm_download_hook(t) + ) def _get_max_orbitals(self, datapath, dataset_name): db_path = os.path.join(datapath, "raw/" + dataset_name + self.db_suffix) diff --git a/nablaDFT/dataset/split.py b/nablaDFT/dataset/split.py index df9e8c4..b049d45 100644 --- a/nablaDFT/dataset/split.py +++ b/nablaDFT/dataset/split.py @@ -40,6 +40,8 @@ def split(self, dataset, *split_sizes): class TestSplit(SplittingStrategy): """Splitting strategy that allows to put all dataset elements in test split without index permutations. + Used for schnetpack datasets to overcome limitation + when train and val split are empty. """ def split(self, dataset, *split_sizes): diff --git a/nablaDFT/optimization/calculator.py b/nablaDFT/optimization/calculator.py index 9a50685..dc64051 100644 --- a/nablaDFT/optimization/calculator.py +++ b/nablaDFT/optimization/calculator.py @@ -13,6 +13,14 @@ class BatchwiseCalculator: """ Base class calculator for neural network models for batchwise optimization. + Args: + model (nn.Module): loaded trained model. 
+ device (str): device used for calculations (default="cpu") + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str): energy units used by model (default="eV") + position_unit (str): position units used by model (default="Angstrom") + dtype (torch.dtype): required data type for the model input (default: torch.float32) """ def __init__( @@ -25,34 +33,6 @@ def __init__( position_unit: str = "Ang", dtype: torch.dtype = torch.float32, ): - """ - model: - path to trained model or trained model - - atoms_converter: - Class used to convert ase Atoms objects to schnetpack input - - device: - device used for calculations (default="cpu") - - auxiliary_output_modules: - auxiliary module to manipulate output properties (e.g., prior energy or forces) - - energy_key: - name of energies in model (default="energy") - - force_key: - name of forces in model (default="forces") - - energy_unit: - energy units used by model (default="eV") - - position_unit: - position units used by model (default="Angstrom") - - dtype: - required data type for the model input (default: torch.float32) - """ self.results = None self.atoms = None @@ -93,11 +73,10 @@ def _requires_calculation(self, property_keys: List[str], atoms: List[ase.Atoms] def get_forces( self, atoms: List[ase.Atoms], fixed_atoms_mask: Optional[List[int]] = None ) -> np.array: - """ - atoms: - - fixed_atoms_mask: - list of indices corresponding to atoms with positions fixed in space. + """Return atom's forces. + Args: + atoms (List[ase.Atoms]): list of ase.Atoms objects. + fixed_atoms_mask (optional, List[int]): list of indices corresponding to atoms with positions fixed in space. """ if self._requires_calculation( property_keys=[self.energy_key, self.force_key], atoms=atoms @@ -118,8 +97,11 @@ def calculate(self, atoms: List[ase.Atoms]) -> None: class PyGBatchwiseCalculator(BatchwiseCalculator): - """Batchwise calculator for PyTorch Geometric models for batchwise optimization""" - + """Batchwise calculator for PyTorch Geometric models for batchwise optimization + Args: + model (nn.Module): loaded PyG model. + """ + def __init__( self, model: nn.Module, @@ -157,7 +139,12 @@ def calculate(self, atoms: List[ase.Atoms]) -> None: class SpkBatchwiseCalculator(BatchwiseCalculator): - """Batchwise calculator for SchNetPack models for batchwise optimization""" + """Batchwise calculator for SchNetPack models for batchwise optimization. + + Args: + model (nn.Module): loaded train schnetpack model. + atoms_converter (AtomsConverter): Class used to convert ase Atoms objects to schnetpack input. + """ def __init__( self, diff --git a/nablaDFT/optimization/opt_utils.py b/nablaDFT/optimization/opt_utils.py index 8fc07ce..7eaa54f 100644 --- a/nablaDFT/optimization/opt_utils.py +++ b/nablaDFT/optimization/opt_utils.py @@ -10,6 +10,11 @@ def np_scatter_add(updates, indices, shape): def atoms_list_to_PYG(ase_atoms_list, device): + """Function to convert ase.Atoms object to PyG data batches. + Args: + ase_atoms_list (List[ase.Atoms]): list of ase.Atoms object to convert. + device (str): task device. 
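    Example (a minimal sketch; ase.build.molecule is only used here to create
    throwaway inputs, and the returned object follows the function body below):

        >>> from ase.build import molecule
        >>> batch = atoms_list_to_PYG([molecule("H2O"), molecule("CH4")], device="cpu")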
+ """ data = [] for ase_atoms in ase_atoms_list: z = torch.from_numpy(ase_atoms.numbers).long() diff --git a/nablaDFT/optimization/optimizers.py b/nablaDFT/optimization/optimizers.py index cddc75f..f6577f7 100644 --- a/nablaDFT/optimization/optimizers.py +++ b/nablaDFT/optimization/optimizers.py @@ -305,7 +305,6 @@ class ASEBatchwiseLBFGS(BatchwiseOptimizer): extension/adaptation of the ase.optimize.LBFGS optimizer particularly designed for batch-wise relaxation of atomic structures. The inverse Hessian is approximated for each sample separately, which allows for optimizing batches of different structures/compositions. - """ atoms = None @@ -326,7 +325,7 @@ def __init__( fixed_atoms_mask: Optional[List[int]] = None, verbose: bool = False, ): - """Parameters: + """Args: calculator: This calculator provides properties such as forces and energy, which can be used for MD simulations or diff --git a/nablaDFT/optimization/pyg_ase_interface.py b/nablaDFT/optimization/pyg_ase_interface.py index c3c437e..8cde418 100644 --- a/nablaDFT/optimization/pyg_ase_interface.py +++ b/nablaDFT/optimization/pyg_ase_interface.py @@ -40,6 +40,17 @@ class PYGCalculator(Calculator): """ ASE calculator for pytorch geometric machine learning models. + Args: + model_file (str): path to trained model + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str, float): energy units used by model (default="kcal/mol") + position_unit (str, float): position units used by model (default="Angstrom") + device (torch.device): device used for calculations (default="cpu") + dtype (torch.dtype): select model precision (default=float32) + converter (callable): converter used to set up input batches + additional_inputs (dict): additional inputs required for some transforms in the converter. + **kwargs: Additional arguments for basic ASE calculator class. """ energy = "energy" @@ -58,19 +69,6 @@ def __init__( dtype: torch.dtype = torch.float32, **kwargs, ): - """ - Args: - model_file (str): path to trained model - energy_key (str): name of energies in model (default="energy") - force_key (str): name of forces in model (default="forces") - energy_unit (str, float): energy units used by model (default="kcal/mol") - position_unit (str, float): position units used by model (default="Angstrom") - device (torch.device): device used for calculations (default="cpu") - dtype (torch.dtype): select model precision (default=float32) - converter (callable): converter used to set up input batches - additional_inputs (dict): additional inputs required for some transforms in the converter. - **kwargs: Additional arguments for basic ase calculator class - """ Calculator.__init__(self, **kwargs) self.energy_key = energy_key @@ -124,8 +122,8 @@ def calculate( """ Args: atoms (ase.Atoms): ASE atoms object. - properties (list of str): select properties computed and stored to results. - system_changes (list of str): List of changes for ASE. + properties (List[str]): select properties computed and stored to results. + system_changes (List[str]): List of changes for ASE. 
""" # First call original calculator to set atoms attribute # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator) @@ -151,6 +149,19 @@ def calculate( class PYGAseInterface: """ Interface for ASE calculations (optimization and molecular dynamics) + + Args: + molecule_path (str): Path to initial geometry + working_dir (str): Path to directory where files should be stored + model_file (str): path to trained model + energy_key (str): name of energies in model (default="energy") + force_key (str): name of forces in model (default="forces") + energy_unit (str, float): energy units used by model (default="kcal/mol") + position_unit (str, float): position units used by model (default="Angstrom") + device (torch.device): device used for calculations (default="cpu") + dtype (torch.dtype): select model precision (default=float32) + optimizer_class (ase.optimize.optimizer): ASE optimizer used for structure relaxation. + fixed_atoms (list(int)): list of indices corresponding to atoms with positions fixed in space. """ def __init__( @@ -168,21 +179,6 @@ def __init__( optimizer_class: type = QuasiNewton, fixed_atoms: Optional[List[int]] = None, ): - """ - Args: - molecule_path: Path to initial geometry - working_dir: Path to directory where files should be stored - model_file (str): path to trained model - energy_key (str): name of energies in model (default="energy") - force_key (str): name of forces in model (default="forces") - energy_unit (str, float): energy units used by model (default="kcal/mol") - position_unit (str, float): position units used by model (default="Angstrom") - device (torch.device): device used for calculations (default="cpu") - dtype (torch.dtype): select model precision (default=float32) - optimizer_class (ase.optimize.optimizer): ASE optimizer used for structure relaxation. - fixed_atoms (list(int)): list of indices corresponding to atoms with positions fixed in space. - - """ # Setup directory self.working_dir = working_dir if not os.path.exists(self.working_dir): diff --git a/nablaDFT/optimization/task.py b/nablaDFT/optimization/task.py index 6bf64f5..5c9831b 100644 --- a/nablaDFT/optimization/task.py +++ b/nablaDFT/optimization/task.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List import tqdm from ase.db import connect @@ -7,7 +7,15 @@ class BatchwiseOptimizeTask: - """Use for batchwise molecules conformations geometry optimization.""" + """Use for batchwise molecules conformations geometry optimization. + + Args: + input_datapath (str): path to ASE database with molecules. + output_datapath (str): path to output database. + optimizer (BatchwiseOptimizer): used for molecule geometry optimization. + converter (AtomsConverter): optional, mandatory for SchNetPack models. + batch_size (int): number of samples per batch. + """ def __init__( self, @@ -16,14 +24,6 @@ def __init__( optimizer: BatchwiseOptimizer, batch_size: int, ) -> None: - """ - Args: - input_datapath (str): path to ASE database with molecules. - output_datapath (str): path to output database. - optimizer (BatchwiseOptimizer): used for molecule geometry optimization. - converter (AtomsConverter): optional, mandatory for SchNetPack models. - batch_size (int): number of samples per batch. 
- """ self.optimizer = optimizer self.bs = batch_size self.data_db_conn = None diff --git a/nablaDFT/pipelines.py b/nablaDFT/pipelines.py index 21a6aed..36d5dbb 100644 --- a/nablaDFT/pipelines.py +++ b/nablaDFT/pipelines.py @@ -30,6 +30,10 @@ def predict( ): """Function for prediction loop execution. Saves model prediction to "predictions" directory. + + Args: + ckpt_path (str): path to model checkpoint. + config (DictConfig): config for task. see r'config/' for examples. """ trainer.logger = False # need this to disable save_hyperparameters during predict, otherwise OmegaConf DictConf can't be dumped to YAML pred_path = os.path.join(os.getcwd(), "predictions") @@ -43,6 +47,10 @@ def predict( def optimize(config: DictConfig, ckpt_path: str): """Function for batched molecules optimization. Uses model defined in config. + + Args: + ckpt_path (str): path to model checkpoint. + config (DictConfig): config for task. see r'config/' for examples. """ output_dir = config.get("output_dir") if not os.path.exists(output_dir): @@ -66,6 +74,12 @@ def optimize(config: DictConfig, ckpt_path: str): def run(config: DictConfig): + """Main function to perform task runs on nablaDFT datasets. + Refer to r'nablaDFT/README.md' for detailed description of run configuration. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + """ if config.get("seed"): seed_everything(config.seed) job_type = config.get("job_type") diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py index 297b624..35717ba 100644 --- a/nablaDFT/utils.py +++ b/nablaDFT/utils.py @@ -27,7 +27,11 @@ logger = logging.getLogger() def get_file_size(url: str) -> int: - """Returns file size in bytes""" + """Returns file size in bytes + + Args: + url (str): url of file to download + """ req = request.Request(url, method="HEAD") with request.urlopen(req) as f: file_size = f.headers.get('Content-Length') @@ -87,7 +91,11 @@ def init_wandb(): def download_model(config: DictConfig) -> str: - """Downloads best model checkpoint from vault.""" + """Downloads best model checkpoint from vault. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + """ model_name = config.get("name") ckpt_path = os.path.join( hydra.utils.get_original_cwd(), @@ -108,6 +116,12 @@ def download_model(config: DictConfig) -> str: def load_model(config: DictConfig, ckpt_path: str) -> LightningModule: + """Instantiates model and loads model weights from checkpoint. + + Args: + config (DictConfig): config for task. see r'config/' for examples. + ckpt_path (str): path to checkpoint. 
+ """ model: LightningModule = hydra.utils.instantiate(config.model) if ckpt_path is None: warnings.warn( From ffb3283c95a5f4fc32e82010e6f0085a814dbccb Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 23 Apr 2024 13:03:37 +0300 Subject: [PATCH 48/57] revert some changes --- nablaDFT/dataset/nablaDFT_dataset.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index c771470..9f9999c 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -139,6 +139,16 @@ def __init__( self.dataset_name = dataset_name self.seed = seed self.sizes = [train_size, val_size] + dataloader_keys = [ + "batch_size", "num_workers", + "pin_memory", "persistent_workers" + ] + self.dataloader_kwargs = {} + for key in dataloader_keys: + val = kwargs.get(key, None) + self.dataloader_kwargs[key] = val + if val is not None: + del kwargs[key] self.kwargs = kwargs def dataloader(self, dataset, **kwargs): From 4c3f6e9773393fb1241a39ffabe0b845acce5906 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 23 Apr 2024 16:20:38 +0300 Subject: [PATCH 49/57] fixed test hamiltonian datasets download --- nablaDFT/dataset/pyg_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index cc24998..892bf6a 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -207,7 +207,7 @@ def get(self, idx): def download(self) -> None: with open(nablaDFT.__path__[0] + "/links/hamiltonian_databases.json") as f: data = json.load(f) - url = data["train_databases"][self.dataset_name] + url = data[f"{self.split}_databases"][self.dataset_name] file_size = get_file_size(url) with tqdm( unit="B", From 922842dbbace7e4db67676e5d460a977631322d2 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Wed, 24 Apr 2024 11:51:11 +0300 Subject: [PATCH 50/57] change graphormer scheduler instantiation --- config/model/graphormer3d-base.yaml | 2 +- config/model/graphormer3d-large.yaml | 2 +- nablaDFT/__init__.py | 2 +- nablaDFT/graphormer/graphormer_3d.py | 55 ++-------------------------- 4 files changed, 6 insertions(+), 55 deletions(-) diff --git a/config/model/graphormer3d-base.yaml b/config/model/graphormer3d-base.yaml index 40dd41b..dbf9cb9 100644 --- a/config/model/graphormer3d-base.yaml +++ b/config/model/graphormer3d-base.yaml @@ -20,7 +20,7 @@ optimizer: lr: 3e-4 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} diff --git a/config/model/graphormer3d-large.yaml b/config/model/graphormer3d-large.yaml index cc989cc..f7bf26a 100644 --- a/config/model/graphormer3d-large.yaml +++ b/config/model/graphormer3d-large.yaml @@ -22,7 +22,7 @@ optimizer: weight_decay: 1e-3 lr_scheduler: - _target_: nablaDFT.graphormer.schedulers.get_linear_schedule_with_warmup + _target_: nablaDFT.schedulers.get_linear_schedule_with_warmup _partial_: true num_warmup_steps: ${warmup_steps} num_training_steps: ${max_steps} diff --git a/nablaDFT/__init__.py b/nablaDFT/__init__.py index 1dae3d5..c23f60a 100644 --- a/nablaDFT/__init__.py +++ b/nablaDFT/__init__.py @@ -8,4 +8,4 @@ from . import ase_model from . import painn_pyg -from .schedulers import get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup +from . 
import schedulers diff --git a/nablaDFT/graphormer/graphormer_3d.py b/nablaDFT/graphormer/graphormer_3d.py index 2ec20c4..cf7c4f3 100644 --- a/nablaDFT/graphormer/graphormer_3d.py +++ b/nablaDFT/graphormer/graphormer_3d.py @@ -12,8 +12,8 @@ @torch.jit.script -def softmax_dropout(input, dropout_prob: float, is_training: bool): - return F.dropout(F.softmax(input, -1), dropout_prob, is_training) +def softmax_dropout(x, dropout_prob: float, is_training: bool): + return F.dropout(F.softmax(x, -1), dropout_prob, is_training) class SelfMultiheadAttention(nn.Module): @@ -380,9 +380,7 @@ def step( y = batch.y energy_out, forces_out, mask_out = self(batch) loss_energy = self.loss(energy_out, y) - forces, mask_forces = to_dense_batch( - batch.forces, batch.batch, batch_size=bsz - ) + forces, mask_forces = to_dense_batch(batch.forces, batch.batch, batch_size=bsz) masked_forces_out = forces_out * mask_forces.unsqueeze(-1) loss_forces = self.loss(masked_forces_out, forces) loss = self.loss_forces_coef * loss_forces + self.loss_energy_coef * loss_energy @@ -511,50 +509,3 @@ def _get_batch_size(self, batch): """Function for batch size infer.""" bsz = batch.batch.max().detach().item() + 1 # get batch size return bsz - - -from functools import partial -from torch.optim.lr_scheduler import LambdaLR - - -def get_linear_schedule_with_warmup( - optimizer, num_warmup_steps, num_training_steps, last_epoch=-1 -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L104 - """ - Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after - a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. - - Args: - optimizer ([`~torch.optim.Optimizer`]): - The optimizer for which to schedule the learning rate. - num_warmup_steps (`int`): - The number of steps for the warmup phase. - num_training_steps (`int`): - The total number of training steps. - last_epoch (`int`, *optional*, defaults to -1): - The index of the last epoch when resuming training. - - Return: - `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
- """ - - lr_lambda = partial( - _get_linear_schedule_with_warmup_lr_lambda, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - ) - return LambdaLR(optimizer, lr_lambda, last_epoch) - - -def _get_linear_schedule_with_warmup_lr_lambda( - current_step: int, *, num_warmup_steps: int, num_training_steps: int -): - # from https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/optimization.py#L98 - if current_step < num_warmup_steps: - return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, - float(num_training_steps - current_step) - / float(max(1, num_training_steps - num_warmup_steps)), - ) From 14cbd60a54362985316f5e97a704b37d803f54d8 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 2 May 2024 15:38:00 +0300 Subject: [PATCH 51/57] updated qhnet ckpts links --- nablaDFT/links/models_checkpoints.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 890bf72..632798f 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -55,5 +55,11 @@ "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_10k_epoch_1095_val_loss_0.010159.ckpt", "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_5k_epoch_1331_val_loss_0.012179.ckpt", "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" + }, + "QHNet": { + "dataset_train_100k": null, + "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-10k_dataset_train_10k_epoch=034_val_loss=0.001594.ckpt", + "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-5k_dataset_train_5k_epoch=067_val_loss=0.002631.ckpt", + "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-2k_dataset_train_2k_epoch=055_val_loss=0.002044.ckpt" } } \ No newline at end of file From 97b4caecee2a293da02f47d503dcc976bba1e512 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Thu, 2 May 2024 16:51:45 +0300 Subject: [PATCH 52/57] fix hamiltonian batch --- nablaDFT/dataset/nablaDFT_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index 9f9999c..b905be8 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -158,16 +158,16 @@ def setup(self, stage: str) -> None: raise NotImplementedError def train_dataloader(self): - return self.dataloader(self.dataset_train, shuffle=True, **self.kwargs) + return self.dataloader(self.dataset_train, shuffle=True, **self.dataloader_kwargs) def val_dataloader(self): - return self.dataloader(self.dataset_val, shuffle=False, **self.kwargs) + return self.dataloader(self.dataset_val, shuffle=False, **self.dataloader_kwargs) def test_dataloader(self): - return self.dataloader(self.dataset_test, shuffle=False, **self.kwargs) + return self.dataloader(self.dataset_test, shuffle=False, 
**self.dataloader_kwargs) def predict_dataloader(self): - return self.dataloader(self.dataset_predict, shuffle=False, **self.kwargs) + return self.dataloader(self.dataset_predict, shuffle=False, **self.dataloader_kwargs) class PyGHamiltonianDataModule(PyGDataModule): From 995e806d219e5a80308ca7800cb78bc4f1191ac6 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 7 May 2024 16:26:24 +0300 Subject: [PATCH 53/57] qhnet 100k link --- nablaDFT/links/models_checkpoints.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 632798f..28311d7 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -57,7 +57,7 @@ "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" }, "QHNet": { - "dataset_train_100k": null, + "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-100k_dataset_train_100k_epoch=017_val_loss=0.002419.ckpt", "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-10k_dataset_train_10k_epoch=034_val_loss=0.001594.ckpt", "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-5k_dataset_train_5k_epoch=067_val_loss=0.002631.ckpt", "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-2k_dataset_train_2k_epoch=055_val_loss=0.002044.ckpt" From b63d6aa16fd7ff063e7f56bd1fdc0ac868eb23f8 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Tue, 7 May 2024 19:48:23 +0300 Subject: [PATCH 54/57] fixed qhnet metrics --- config/model/qhnet.yaml | 2 +- nablaDFT/qhnet/__init__.py | 3 ++- nablaDFT/qhnet/loss.py | 1 - nablaDFT/qhnet/masked_mae.py | 19 +++++++++++++++++++ nablaDFT/qhnet/qhnet.py | 10 +++------- 5 files changed, 25 insertions(+), 10 deletions(-) create mode 100644 nablaDFT/qhnet/masked_mae.py diff --git a/config/model/qhnet.yaml b/config/model/qhnet.yaml index f630b1f..0ef7ed4 100644 --- a/config/model/qhnet.yaml +++ b/config/model/qhnet.yaml @@ -46,7 +46,7 @@ metric: _convert_: all task_metrics: hamiltonian: - _target_: torchmetrics.MeanAbsoluteError + _target_: nablaDFT.qhnet.MaskedMeanAbsoluteError ema: _target_: torch_ema.ExponentialMovingAverage diff --git a/nablaDFT/qhnet/__init__.py b/nablaDFT/qhnet/__init__.py index d38b468..d4ead39 100644 --- a/nablaDFT/qhnet/__init__.py +++ b/nablaDFT/qhnet/__init__.py @@ -1 +1,2 @@ -from .qhnet import QHNet, QHNetLightning \ No newline at end of file +from .qhnet import QHNet, QHNetLightning +from . 
masked_mae import MaskedMeanAbsoluteError diff --git a/nablaDFT/qhnet/loss.py b/nablaDFT/qhnet/loss.py index d4d87a5..cdfbab9 100644 --- a/nablaDFT/qhnet/loss.py +++ b/nablaDFT/qhnet/loss.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F class HamiltonianLoss(nn.Module): diff --git a/nablaDFT/qhnet/masked_mae.py b/nablaDFT/qhnet/masked_mae.py new file mode 100644 index 0000000..0e2d345 --- /dev/null +++ b/nablaDFT/qhnet/masked_mae.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor + +from torchmetrics.functional.regression.mae import _mean_absolute_error_update +from torchmetrics.regression import MeanAbsoluteError + + +class MaskedMeanAbsoluteError(MeanAbsoluteError): + """Overloaded MAE for usage with block diagonal matrix. + Mask calculated from target tensor.""" + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. + Both inputs are block diagonal Torch Tensors.""" + sum_abs_error, _ = _mean_absolute_error_update(preds, target) + num_obs = torch.count_nonzero(target).item() + + self.sum_abs_error += sum_abs_error + self.total += num_obs diff --git a/nablaDFT/qhnet/qhnet.py b/nablaDFT/qhnet/qhnet.py index 3652f6a..348a6f6 100644 --- a/nablaDFT/qhnet/qhnet.py +++ b/nablaDFT/qhnet/qhnet.py @@ -305,7 +305,7 @@ def step(self, batch, calculate_metrics: bool = False): target = {'hamiltonian': hamiltonian} loss = self._calculate_loss(preds, target, masks) if calculate_metrics: - metrics = self._calculate_metrics(preds, target, masks) + metrics = self._calculate_metrics(preds, target) return loss, metrics return loss @@ -352,8 +352,7 @@ def validation_step(self, batch, batch_idx): def test_step(self, batch, batch_idx): bsz = self._get_batch_size(batch) - with self.ema.average_parameters(): - loss, metrics = self.step(batch, calculate_metrics=True) + loss, metrics = self.step(batch, calculate_metrics=True) self.log( "test/loss", loss, @@ -424,12 +423,9 @@ def _calculate_loss(self, y_pred, y_true, masks) -> float: ) return total_loss - def _calculate_metrics(self, y_pred, y_true, mask) -> Dict: + def _calculate_metrics(self, y_pred, y_true) -> Dict: """Function for metrics calculation during step.""" - # TODO: temp workaround for metric normalization by mask sum - norm_coef = (y_pred['hamiltonian'].numel() / mask.sum()) metric = self.hparams.metric(y_pred, y_true) - metric['hamiltonian'] = metric['hamiltonian'] * norm_coef return metric def _log_current_lr(self) -> None: From cbf5aa725d095149412875209a569b2b071dbe49 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 13 May 2024 15:24:09 +0300 Subject: [PATCH 55/57] dataset names updated from 2k/5k/10/100k to tiny/small/medium/large --- config/dimenetplusplus.yaml | 2 +- config/equiformer_v2_oc20.yaml | 2 +- config/escn-oc.yaml | 2 +- config/gemnet-oc.yaml | 2 +- config/gemnet-oc_test.yaml | 2 +- config/graphormer3d.yaml | 2 +- config/painn-oc.yaml | 25 +++++++ config/painn.yaml | 2 +- config/qhnet.yaml | 2 +- config/schnet.yaml | 2 +- config/schnet_test.yaml | 2 +- nablaDFT/dataset/nablaDFT_dataset.py | 2 +- nablaDFT/dataset/pyg_datasets.py | 4 +- nablaDFT/links/energy_databases.json | 16 ++--- nablaDFT/links/hamiltonian_databases.json | 16 ++--- nablaDFT/links/models_checkpoints.json | 80 +++++++++++------------ nablaDFT/pipelines.py | 3 +- nablaDFT/utils.py | 2 + 18 files changed, 98 insertions(+), 70 deletions(-) create mode 100644 config/painn-oc.yaml diff --git a/config/dimenetplusplus.yaml b/config/dimenetplusplus.yaml index 930cf80..6ad943f 
100644 --- a/config/dimenetplusplus.yaml +++ b/config/dimenetplusplus.yaml @@ -1,6 +1,6 @@ # Global variables name: DimeNet++ -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/equiformer_v2_oc20.yaml b/config/equiformer_v2_oc20.yaml index 547daa9..5c7ed16 100644 --- a/config/equiformer_v2_oc20.yaml +++ b/config/equiformer_v2_oc20.yaml @@ -1,6 +1,6 @@ # Global variables name: Equiformer_v2 -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/escn-oc.yaml b/config/escn-oc.yaml index 14c6933..c975d16 100644 --- a/config/escn-oc.yaml +++ b/config/escn-oc.yaml @@ -1,6 +1,6 @@ # Global variables name: ESCN-OC -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/gemnet-oc.yaml b/config/gemnet-oc.yaml index e4fbf52..0b900f5 100644 --- a/config/gemnet-oc.yaml +++ b/config/gemnet-oc.yaml @@ -1,6 +1,6 @@ # Global variables name: GemNet-OC -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/gemnet-oc_test.yaml b/config/gemnet-oc_test.yaml index 29ef46a..0cac57e 100644 --- a/config/gemnet-oc_test.yaml +++ b/config/gemnet-oc_test.yaml @@ -1,6 +1,6 @@ # Global variables name: GemNet-OC -dataset_name: dataset_test_conformations_2k +dataset_name: dataset_test_conformations_tiny max_steps: 1000000 warmup_steps: 0 job_type: test diff --git a/config/graphormer3d.yaml b/config/graphormer3d.yaml index a7b4d42..01bba18 100644 --- a/config/graphormer3d.yaml +++ b/config/graphormer3d.yaml @@ -1,6 +1,6 @@ # Global variables name: Graphormer3D-small -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 60000 job_type: train diff --git a/config/painn-oc.yaml b/config/painn-oc.yaml new file mode 100644 index 0000000..2193b1e --- /dev/null +++ b/config/painn-oc.yaml @@ -0,0 +1,25 @@ +# Global variables +name: PaiNN-OC-MSE-energy +dataset_name: dataset_test_conformations_2k +max_steps: 1000000 +job_type: test +pretrained: False +ckpt_path: ./checkpoints/PaiNN-OC-MSE-energy/PaiNN-OC-MSE-energy_dataset_train_2k_epoch=2037_val_loss=0.002064.ckpt # path to checkpoint for training resume or test run + +# configs +defaults: + - _self_ + - datamodule: nablaDFT_pyg_test.yaml # dataset config + - model: painn-oc.yaml # model config + - callbacks: default.yaml # pl callbacks config + - loggers: wandb.yaml # pl loggers config + - trainer: test.yaml # trainer config + +# need this to set working dir as current dir +hydra: + output_subdir: null + run: + dir: . 
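# Editor's note (hypothetical usage, not part of the original config): assuming the
# repository's usual Hydra entry point (run.py), a run with this config would be
# launched along the lines of
#   python run.py --config-name painn-oc
# with Hydra resolving the datamodule/model/callbacks/loggers/trainer groups from
# the defaults list above.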
+original_work_dir: ${hydra:runtime.cwd} + +seed: 23 \ No newline at end of file diff --git a/config/painn.yaml b/config/painn.yaml index 47ce455..9a10dcb 100644 --- a/config/painn.yaml +++ b/config/painn.yaml @@ -1,6 +1,6 @@ # Global variables name: PaiNN -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/qhnet.yaml b/config/qhnet.yaml index 495be9b..55a4b78 100644 --- a/config/qhnet.yaml +++ b/config/qhnet.yaml @@ -1,6 +1,6 @@ # Global variables name: QHNet -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 warmup_steps: 0 job_type: train diff --git a/config/schnet.yaml b/config/schnet.yaml index 50205e0..a28b5a2 100644 --- a/config/schnet.yaml +++ b/config/schnet.yaml @@ -1,6 +1,6 @@ # Global variables name: SchNet -dataset_name: dataset_train_2k +dataset_name: dataset_train_tiny max_steps: 1000000 job_type: train pretrained: False diff --git a/config/schnet_test.yaml b/config/schnet_test.yaml index a804365..79ef52e 100644 --- a/config/schnet_test.yaml +++ b/config/schnet_test.yaml @@ -1,6 +1,6 @@ # Global variables name: SchNet -dataset_name: dataset_test_conformations_2k +dataset_name: dataset_test_conformations_tiny max_steps: 1000000 job_type: test pretrained: False diff --git a/nablaDFT/dataset/nablaDFT_dataset.py b/nablaDFT/dataset/nablaDFT_dataset.py index b905be8..c84044d 100644 --- a/nablaDFT/dataset/nablaDFT_dataset.py +++ b/nablaDFT/dataset/nablaDFT_dataset.py @@ -36,7 +36,7 @@ class ASENablaDFT(AtomsDataModule): def __init__( self, split: str, - dataset_name: str = "dataset_train_2k", + dataset_name: str = "dataset_train_tiny", datapath: str = "database", data_workdir: Optional[str] = "logs", batch_size: int = 2000, diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index 892bf6a..c7019f3 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -44,7 +44,7 @@ def processed_file_names(self) -> str: def __init__( self, datapath: str = "database", - dataset_name: str = "dataset_train_2k", + dataset_name: str = "dataset_train_tiny", split: str = "train", transform: Callable = None, pre_transform: Callable = None, @@ -148,7 +148,7 @@ def processed_file_names(self) -> str: def __init__( self, datapath: str = "database", - dataset_name: str = "dataset_train_2k", + dataset_name: str = "dataset_train_tiny", split: str = "train", include_hamiltonian: bool = True, include_overlap: bool = False, diff --git a/nablaDFT/links/energy_databases.json b/nablaDFT/links/energy_databases.json index 6657d5f..0070cdd 100644 --- a/nablaDFT/links/energy_databases.json +++ b/nablaDFT/links/energy_databases.json @@ -1,17 +1,17 @@ { "train_databases": { "dataset_train_full": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_full_v2_formation_energy.db", - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_100k_v2_formation_energy_w_forces.db", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_10k_v2_formation_energy_w_forces.db", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_5k_v2_formation_energy_w_forces.db", - "dataset_train_2k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_2k_v2_formation_energy_w_forces.db" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_100k_v2_formation_energy_w_forces.db", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_10k_v2_formation_energy_w_forces.db", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_5k_v2_formation_energy_w_forces.db", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/train_2k_v2_formation_energy_w_forces.db" }, "test_databases": { "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_structures_v2_formation_energy_forces.db", "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_full_scaffolds_v2_formation_energy_forces.db", - "dataset_test_conformations_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_100k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_10k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_5k_conformers_v2_formation_energy_w_forces.db", - "dataset_test_conformations_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_2k_conformers_v2_formation_energy_w_forces.db" + "dataset_test_conformations_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_100k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_10k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_5k_conformers_v2_formation_energy_w_forces.db", + "dataset_test_conformations_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/energy_databases/test_2k_conformers_v2_formation_energy_w_forces.db" } } diff --git a/nablaDFT/links/hamiltonian_databases.json b/nablaDFT/links/hamiltonian_databases.json index 3c5fc90..2646558 100644 --- a/nablaDFT/links/hamiltonian_databases.json +++ b/nablaDFT/links/hamiltonian_databases.json @@ -1,16 +1,16 @@ { "train_databases": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_100k.db", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_10k.db", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_5k.db", - "dataset_train_2k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_2k.db" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_100k.db", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_10k.db", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_5k.db", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/train_2k.db" }, "test_databases": { "dataset_test_structures": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_structures.db", "dataset_test_scaffolds": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_scaffolds.db", - "dataset_test_conformations_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_100k_conformers.db", - "dataset_test_conformations_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_10k_conformers.db", - "dataset_test_conformations_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_5k_conformers.db", - "dataset_test_conformations_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_2k_conformers.db" + "dataset_test_conformations_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_100k_conformers.db", + "dataset_test_conformations_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_10k_conformers.db", + "dataset_test_conformations_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_5k_conformers.db", + "dataset_test_conformations_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/hamiltonian_databases/test_2k_conformers.db" } } \ No newline at end of file diff --git a/nablaDFT/links/models_checkpoints.json b/nablaDFT/links/models_checkpoints.json index 28311d7..44c33d1 100644 --- a/nablaDFT/links/models_checkpoints.json +++ b/nablaDFT/links/models_checkpoints.json @@ -1,65 +1,65 @@ { "SchNet": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_100k.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_10k.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_5k.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_2k.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_100k.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_10k.ckpt", 
+ "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/SchNet/schnet_2k.ckpt" }, "PaiNN": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_100k.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_10k.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_5k.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_2k.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_100k.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_10k.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/PaiNN/painn_2k.ckpt" }, "DimeNet++": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_100k_epoch=0258.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_10k_epoch=0651.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_5k_epoch=0442.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_2k_epoch=0545.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_100k_epoch=0258.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_10k_epoch=0651.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_5k_epoch=0442.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/DimeNet%2b%2b/DimeNet%2b%2b_dataset_train_2k_epoch=0545.ckpt" }, "PhiSNet": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_100k_split.pt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_10k_split.pt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_5k_split.pt", - "dataset_train_2k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_2k_split.pt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_100k_split.pt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_10k_split.pt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_5k_split.pt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/phisnet/model_2k_split.pt" }, "SchNOrb": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_100k_split.pt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_10k_split.pt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_5k_split.pt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_2k_split.pt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_100k_split.pt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_10k_split.pt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_5k_split.pt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/schnorb/model_2k_split.pt" }, "GemNet-OC": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_100k_epoch=085_val_loss=0.028551.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_10k_epoch=337_val_loss=0.060103.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_5k_epoch=427_val_loss=0.078626.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_2k_epoch=400_val_loss=0.142715.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_100k_epoch=085_val_loss=0.028551.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_10k_epoch=337_val_loss=0.060103.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_5k_epoch=427_val_loss=0.078626.ckpt", + "dataset_train_tiny": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/GemNet-OC/GemNet-OC_dataset_train_2k_epoch=400_val_loss=0.142715.ckpt" }, "ESCN-OC": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_100k_epoch=028_val_loss=0.029561.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_10k_epoch=189_val_loss=0.052877.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_5k_epoch=175_val_loss=0.065441.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_2k_epoch=253_val_loss=0.124682.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_100k_epoch=028_val_loss=0.029561.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_10k_epoch=189_val_loss=0.052877.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_5k_epoch=175_val_loss=0.065441.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/ESCN-OC/ESCN-OC_dataset_train_2k_epoch=253_val_loss=0.124682.ckpt" }, "Equiformer_v2": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_100k_epoch=010_val_loss=0.302613.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_10k_epoch=107_val_loss=0.337899.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_2k_epoch=409_val_loss=0.354093.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_100k_epoch=010_val_loss=0.302613.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_10k_epoch=107_val_loss=0.337899.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_5k.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/Equiformer_OC20/Equiformer_OC20_dataset_train_2k_epoch=409_val_loss=0.354093.ckpt" }, "Graphormer3D-small": { - "dataset_train_100k": 
"https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_100k_epoch_420_val_loss_0.005773.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_10k_epoch_1095_val_loss_0.010159.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_5k_epoch_1331_val_loss_0.012179.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_100k_epoch_420_val_loss_0.005773.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_10k_epoch_1095_val_loss_0.010159.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_5k_epoch_1331_val_loss_0.012179.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/graphormer/small/Graphormer3D-small_dataset_train_2k_epoch_672_val_loss_0.018476.ckpt" }, "QHNet": { - "dataset_train_100k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-100k_dataset_train_100k_epoch=017_val_loss=0.002419.ckpt", - "dataset_train_10k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-10k_dataset_train_10k_epoch=034_val_loss=0.001594.ckpt", - "dataset_train_5k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-5k_dataset_train_5k_epoch=067_val_loss=0.002631.ckpt", - "dataset_train_2k": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-2k_dataset_train_2k_epoch=055_val_loss=0.002044.ckpt" + "dataset_train_large": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-100k_dataset_train_100k_epoch=017_val_loss=0.002419.ckpt", + "dataset_train_medium": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-10k_dataset_train_10k_epoch=034_val_loss=0.001594.ckpt", + "dataset_train_small": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-5k_dataset_train_5k_epoch=067_val_loss=0.002631.ckpt", + "dataset_train_tiny": "https://a002dlils-kadurin-nabladft.obs.ru-moscow-1.hc.sbercloud.ru/data/nablaDFTv2/models_checkpoints/QHNet/QHNet-2k_dataset_train_2k_epoch=055_val_loss=0.002044.ckpt" } } \ No newline at end of file diff --git a/nablaDFT/pipelines.py b/nablaDFT/pipelines.py index 36d5dbb..03f5cd0 100644 --- a/nablaDFT/pipelines.py +++ b/nablaDFT/pipelines.py @@ -35,7 +35,8 @@ def predict( ckpt_path (str): path to model checkpoint. 
config (DictConfig): config for task. see r'config/' for examples. """ - trainer.logger = False # need this to disable save_hyperparameters during predict, otherwise OmegaConf DictConf can't be dumped to YAML + trainer.logger = False # need this to disable save_hyperparameters during predict, + # otherwise OmegaConf DictConf can't be dumped to YAML pred_path = os.path.join(os.getcwd(), "predictions") os.makedirs(pred_path, exist_ok=True) predictions = trainer.predict( diff --git a/nablaDFT/utils.py b/nablaDFT/utils.py index 35717ba..a53f96a 100644 --- a/nablaDFT/utils.py +++ b/nablaDFT/utils.py @@ -26,6 +26,7 @@ logger = logging.getLogger() + def get_file_size(url: str) -> int: """Returns file size in bytes @@ -37,6 +38,7 @@ def get_file_size(url: str) -> int: file_size = f.headers.get('Content-Length') return int(file_size) + def tqdm_download_hook(t): """wraps TQDM progress bar instance""" last_block = [0] From 7a64b0938650cc565bc904c9e6751b17a0d6fb22 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 13 May 2024 15:26:30 +0300 Subject: [PATCH 56/57] painn-oc cfg to default --- config/painn-oc.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/painn-oc.yaml b/config/painn-oc.yaml index 2193b1e..0df1706 100644 --- a/config/painn-oc.yaml +++ b/config/painn-oc.yaml @@ -1,19 +1,19 @@ # Global variables -name: PaiNN-OC-MSE-energy -dataset_name: dataset_test_conformations_2k +name: PaiNN-OC +dataset_name: dataset_train_tiny max_steps: 1000000 -job_type: test +job_type: train pretrained: False -ckpt_path: ./checkpoints/PaiNN-OC-MSE-energy/PaiNN-OC-MSE-energy_dataset_train_2k_epoch=2037_val_loss=0.002064.ckpt # path to checkpoint for training resume or test run +ckpt_path: null # path to checkpoint for training resume or test run # configs defaults: - _self_ - - datamodule: nablaDFT_pyg_test.yaml # dataset config + - datamodule: nablaDFT_pyg.yaml # dataset config - model: painn-oc.yaml # model config - callbacks: default.yaml # pl callbacks config - loggers: wandb.yaml # pl loggers config - - trainer: test.yaml # trainer config + - trainer: train.yaml # trainer config # need this to set working dir as current dir hydra: From 3a69aa203a23a73b5422fc72d1fcb6245e201678 Mon Sep 17 00:00:00 2001 From: BerAnton Date: Mon, 13 May 2024 15:53:40 +0300 Subject: [PATCH 57/57] suppres warning during dataset preprocessing --- nablaDFT/dataset/pyg_datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nablaDFT/dataset/pyg_datasets.py b/nablaDFT/dataset/pyg_datasets.py index c7019f3..0932bf6 100644 --- a/nablaDFT/dataset/pyg_datasets.py +++ b/nablaDFT/dataset/pyg_datasets.py @@ -98,10 +98,10 @@ def process(self) -> None: db = connect(self.raw_paths[0]) samples = [] for db_row in tqdm(db.select(), total=len(db)): - z = torch.from_numpy(db_row.numbers).long() - positions = torch.from_numpy(db_row.positions).float() - y = torch.from_numpy(np.array(db_row.data["energy"])).float() - forces = torch.from_numpy(np.array(db_row.data["forces"])).float() + z = torch.from_numpy(db_row.numbers.copy()).long() + positions = torch.from_numpy(db_row.positions.copy()).float() + y = torch.from_numpy(np.array(db_row.data["energy"]).copy()).float() + forces = torch.from_numpy(np.array(db_row.data["forces"]).copy()).float() samples.append(Data(z=z, pos=positions, y=y, forces=forces)) if self.pre_filter is not None:
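Editor's note on the final patch: the warning being silenced is presumably NumPy/PyTorch's complaint about wrapping a non-writable array, since torch.from_numpy shares memory with its input and the ASE database rows can hand back read-only arrays; copying first avoids both the warning and the shared buffer. A minimal reproduction sketch (the array here is synthetic, not taken from the dataset):

import warnings
import numpy as np
import torch

arr = np.arange(3, dtype=np.float64)
arr.setflags(write=False)  # mimic a read-only array coming out of an ASE database row

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.from_numpy(arr)          # shares memory with a non-writable array -> UserWarning
    print(len(caught))             # 1

torch.from_numpy(arr.copy())       # copy first, as in the patch above -> no warning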