# MiSaTo-Dataset: a tutorial

In this notebook, we show how our QM and MD datasets are stored in H5 files and how the data can be loaded for use by a deep learning model.

We start by importing the required packages and setting the file paths.

In [1]:
import h5py
import numpy as np 

import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

from data.components.datasets import MolDataset, ProtDataset
from data.components.transformQM import GNNTransformQM
from data.components.transformMD import GNNTransformMD
from data.qm_datamodule import QMDataModule
from data.md_datamodule import MDDataModule
from data.processing import preprocessing_db



In [2]:
qmh5_file = "../data/QM/h5_files/tiny_qm.hdf5"
norm_file = "../data/QM/h5_files/qm_norm.hdf5"
norm_txtfile = "../data/QM/splits/train_norm.txt"

## Presentation of the H5 files

We open the QM H5 file and the H5 file used to normalize the target values.

In [3]:
qm_H5File = h5py.File(qmh5_file)
qm_normFile = h5py.File(norm_file)

The ligands can be accessed using their pdb-id. Below we list the molecules stored in the file.

In [4]:
qm_H5File.keys()

<KeysViewHDF5 ['10gs', '11gs', '13gs', '16pk', '184l', '185l', '186l', '187l', '188l', '1a07', '1a08', '1a09', '1a0q', '1a0tA', '1a0tB', '1a1b', '1a1c', '1a1e', '1a28', '1a2c', '1a30', '1a37', '1a42', '1a46', '1a4g', '1a4h', '1a4k', '1a4m', '1a4q', '1a4r', '1a4w', '1a50', '1a52', '1a5g', '1a5h', '1a5v', '1a61', '1a69', '1a7c', '1a7t', '1a7x', '1a85', '1a86', '1a8i', '1a8t', '1a94', '1a99', '1a9m', '1a9q', '1a9u']>

You can access the atom coordinates of a molecule as follows:

In [7]:
xyz = qm_H5File["10gs"]["atom_properties"]["atom_properties_values"][:, 0:3]

Target values can be accessed by indexing with the molecule name, then mol_properties, and finally the name of the target value:

In [8]:
qm_H5File["10gs"]["mol_properties"]["Electron_Affinity"][()]

6.0974

We can access the mean and standard deviation of each target value in the same way.
We first specify the target value and then either mean or std.

In [9]:
qm_normFile.keys()

<KeysViewHDF5 ['Electron_Affinity', 'Electronegativity', 'Hardness', 'Ionization_Potential']>

In [10]:
print(qm_normFile["Electron_Affinity"]["mean"][()])
print(qm_normFile["Electron_Affinity"]["std"][()])

6.33265
18.636927
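
The mean and std read above can be used to standardize a target value and to map a prediction back to the original units. The sketch below assumes the usual z-score convention, which the notebook does not state explicitly; the helper names `standardize` and `destandardize` are illustrative and not part of the MiSaTo code.

```python
# Values copied from the cells above (Electron_Affinity of ligand 10gs).
raw_value = 6.0974
mean = 6.33265
std = 18.636927


def standardize(value, mean, std):
    """Map a raw target to a z-score (assumed convention)."""
    return (value - mean) / std


def destandardize(z, mean, std):
    """Invert the z-score to recover a value in the original units."""
    return z * std + mean


z = standardize(raw_value, mean, std)
recovered = destandardize(z, mean, std)
print(f"standardized: {z:.4f}, recovered: {recovered:.4f}")
```

Applying `destandardize` to model outputs is what makes predictions comparable with the raw values stored in the QM H5 file.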


## Datasets and dataloaders

### PyTorch

The QM and MD datasets are wrapped into PyTorch Dataset classes named MolDataset and ProtDataset, respectively. 
The parameters taken by the two classes, as well as their types, can be displayed as follows.

In [11]:
help(MolDataset)

Help on class MolDataset in module data.components.datasets:

class MolDataset(torch.utils.data.dataset.Dataset)
 |  MolDataset(data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |  
 |  Load the QM dataset.
 |  
 |  Method resolution order:
 |      MolDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, data_file, idx_file, target_norm_file, transform, isTrain=False, post_transform=None)
 |      Args:
 |          data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          target_norm_file (str): H5 file path where training mean and std are stored.  
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          isTrain (bool, optional): Flag to standardize the target values (only used

In [12]:
help(ProtDataset)

Help on class ProtDataset in module data.components.datasets:

class ProtDataset(torch.utils.data.dataset.Dataset)
 |  ProtDataset(md_data_file, idx_file, transform=None, post_transform=None)
 |  
 |  Load the MD dataset
 |  
 |  Method resolution order:
 |      ProtDataset
 |      torch.utils.data.dataset.Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __getitem__(self, index: int)
 |  
 |  __init__(self, md_data_file, idx_file, transform=None, post_transform=None)
 |      Args:
 |          md_data_file (str): H5 file path
 |          idx_file (str): path of txt file which contains pdb ids for a specific split such as train, val or test.
 |          transform (obj): class that convert a dict to a PyTorch Geometric graph.
 |          post_transform (PyTorch Geometric, optional): data augmentation. Defaults to None.
 |  
 |  __len__(self) -> int
 |  
 |  ----------------------------------------------------------------------
 |  Data and oth

We can load the data by instantiating MolDataset and providing the QM H5 file, the text file that lists the molecules used for training, and the norm file used to normalize the target values. 

Without any transform, the MolDataset class returns a dictionary that contains the elements and their coordinates. We use the GNNTransformQM class to transform our data into a graph that can be used by a GNN. The parameter post_transform is another transformation, used to perform data augmentation.

In [13]:
train = "../data/QM/splits/train_tinyQM.txt"

transform = T.RandomTranslate(0.25)
batch_size = 128
num_workers = 48

data_train = MolDataset(qmh5_file, train, target_norm_file=norm_file, transform=GNNTransformQM(), post_transform=transform)
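
As a rough pure-Python illustration of what a per-atom random translation such as `T.RandomTranslate(0.25)` does, the sketch below shifts every coordinate of every atom by an independent uniform draw from [-0.25, 0.25]. The function `random_translate` and the toy coordinates are illustrative assumptions; the actual transform operates on the `pos` tensor of a PyTorch Geometric `Data` object.

```python
import random


def random_translate(positions, max_shift, rng=random):
    """Shift each coordinate by an independent draw from [-max_shift, max_shift]."""
    return [
        [coord + rng.uniform(-max_shift, max_shift) for coord in atom]
        for atom in positions
    ]


# Hypothetical coordinates for three atoms (not taken from the dataset).
positions = [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 1.5, 0.3]]
rng = random.Random(0)  # fixed seed so the sketch is reproducible
augmented = random_translate(positions, 0.25, rng)

for before, after in zip(positions, augmented):
    print(before, "->", [round(c, 3) for c in after])
```

Because each atom receives its own shift, this acts as small positional noise rather than a rigid displacement of the whole molecule.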

Finally, we can load our data using the PyTorch DataLoader.

In [14]:
train_loader = DataLoader(data_train, batch_size, shuffle=True, num_workers=0)

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[1602, 25], edge_index=[2, 30354], edge_attr=[30354, 1], y=[60], pos=[1602, 3], id=[30], batch=[1602], ptr=[31])
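
In the printed batch, `id` has 30 entries and `ptr` has 31, which is consistent with 30 molecules whose 1602 atoms are concatenated into one large graph. The pure-Python sketch below shows how the `batch` and `ptr` vectors of such a mini-batch relate to the number of nodes per graph; `make_batch_index` is an illustrative helper, not a PyTorch Geometric function.

```python
from itertools import accumulate


def make_batch_index(nodes_per_graph):
    """Return (batch, ptr) for graphs concatenated into one large graph."""
    # ptr[i] is the index of the first node of graph i; ptr[-1] is the total.
    ptr = [0] + list(accumulate(nodes_per_graph))
    # batch[n] is the index of the graph that node n belongs to.
    batch = [g for g, n in enumerate(nodes_per_graph) for _ in range(n)]
    return batch, ptr


# Hypothetical mini-batch of three molecules with 4, 2 and 3 atoms.
batch, ptr = make_batch_index([4, 2, 3])
print(batch)  # [0, 0, 0, 0, 1, 1, 2, 2, 2]
print(ptr)    # [0, 4, 6, 9]
```

The nodes of graph `i` are the slice `ptr[i]:ptr[i + 1]` of `x` and `pos`.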


### PyTorch Lightning

QMDataModule is a class inheriting from LightningDataModule that instantiates MolDataset for the training, validation and test sets and returns a dataloader for each set. 

We start by instantiating the QMDataModule.

In [15]:
files_root =  "../data/QM"

qmh5file = "h5_files/tiny_qm.hdf5"

tr = "splits/train_tinyQM.txt"
v = "splits/val_tinyQM.txt"
te = "splits/test_tinyQM.txt"

qmdata = QMDataModule(files_root, h5file=qmh5file, train=tr, val=v, test=te, num_workers=0)

Then, we call the setup function to instantiate MolDataset for the training, validation and test sets.

In [16]:
qmdata.setup()

Finally, we can obtain a dataloader for each set.

In [17]:
train_loader = qmdata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break
    

DataBatch(x=[1602, 25], edge_index=[2, 30354], edge_attr=[30354, 1], y=[60], pos=[1602, 3], id=[30], batch=[1602], ptr=[31])


# MD dataset

We generated a tiny h5 file that can be inspected right away. We do this for the structure with pdb-id 10GS.

In [23]:
mdh5_file_tiny = '../data/MD/h5_files/tiny_md.hdf5'
md_H5File_tiny = h5py.File(mdh5_file_tiny)

In [24]:
md_H5File_tiny['10GS'].keys()

<KeysViewHDF5 ['atoms_element', 'atoms_number', 'atoms_residue', 'atoms_type', 'frames_bSASA', 'frames_distance', 'frames_interaction_energy', 'frames_rmsd_ligand', 'molecules_begin_atom_index', 'trajectory_coordinates']>

The beginning of the name of each property indicates its shape:
- atoms_ properties have an entry for each atom of the structure
- frames_ properties have an entry for each of the 100 frames
- molecules_ properties have an entry for each molecule, including the ligand
- trajectory_coordinates has an entry for each atom and each frame

In [25]:
[(key, np.shape(md_H5File_tiny['10GS'][key])) for key in md_H5File_tiny['10GS'].keys()]

[('atoms_element', (6593,)),
 ('atoms_number', (6593,)),
 ('atoms_residue', (6593,)),
 ('atoms_type', (6593,)),
 ('frames_bSASA', (100,)),
 ('frames_distance', (100,)),
 ('frames_interaction_energy', (100,)),
 ('frames_rmsd_ligand', (100,)),
 ('molecules_begin_atom_index', (3,)),
 ('trajectory_coordinates', (100, 6593, 3))]
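
The naming convention can be turned into a small consistency check. The sketch below predicts the shape of each property from its name and compares it with the shapes printed above for 10GS; `expected_shape` is an illustrative helper and only covers the keys of the unprocessed file.

```python
def expected_shape(key, n_atoms, n_frames, n_molecules):
    """Predict the shape of a property from the prefix of its name."""
    if key == "trajectory_coordinates":
        return (n_frames, n_atoms, 3)
    if key.startswith("atoms_"):
        return (n_atoms,)
    if key.startswith("frames_"):
        return (n_frames,)
    if key.startswith("molecules_"):
        return (n_molecules,)
    raise ValueError(f"unknown prefix in {key!r}")


# Shapes copied from the output above for structure 10GS.
shapes = {
    "atoms_element": (6593,),
    "atoms_number": (6593,),
    "atoms_residue": (6593,),
    "atoms_type": (6593,),
    "frames_bSASA": (100,),
    "frames_distance": (100,),
    "frames_interaction_energy": (100,),
    "frames_rmsd_ligand": (100,),
    "molecules_begin_atom_index": (3,),
    "trajectory_coordinates": (100, 6593, 3),
}

mismatches = [
    key for key, shape in shapes.items()
    if expected_shape(key, n_atoms=6593, n_frames=100, n_molecules=3) != shape
]
print(mismatches)  # []
```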

To run models on the MD dataset, you will most likely need to preprocess the h5 file to suit your model. We provide a preprocessing script (see data/processing/preprocessing_db.py) that can filter out the atom types you are not interested in (e.g. H-atoms) or calculate values of interest for your model.
Here, we show how to use the script to calculate the adaptability values on the dataset and strip the H-atoms.  
In this notebook we define a new Args class. If you run the script from the terminal, provide these values as command-line parameters instead.

In [9]:
class Args:
    # input file
    datasetIn = "../data/MD/h5_files/tiny_md.hdf5"
    # Feature that should be stripped, e.g. atoms_element or atoms_type
    strip_feature = "atoms_element"
    # Value to strip, e.g. if strip_feature = atoms_element, use 1 for H.
    strip_value = 1
    # Start index of structures
    begin = 0
    # End index of structures
    end = 20 
    # We calculate the adaptability for each atom.
    # The default behaviour also strips H atoms; if no stripping should be performed, set strip_value to -1.
    Adaptability = True
    # If set to True this will create a new feature that combines one entry for each protein AA but all ligand entries; 
    # e.g. for only ca set strip_feature = atoms_type and strip_value = 14
    Pres_Lat = False
    # We strip the complex by given distance (in Angstrom) from COG of molecule, 
    # use e.g. 15.0. If default value is given (0.0) no pocket stripping will be applied.
    Pocket = 0.0
    # output file name and location
    datasetOut = "../data/MD/h5_files/tiny_md_out.hdf5"


args = Args()

preprocessing_db.main(args)

../data/MD/h5_files/tiny_md_out.hdf5
Removing existing output file...
10GS 1
Stripping  atoms_element 1  and calculating adaptability for the atoms that were not stripped.
11GS 2
13GS 3
16PK 4
184L 5
185L 6
186L 7
187L 8
188L 9
1A07 10
1A08 11
1A09 12
1A0Q 13
1A1B 14
1A1C 15
1A1E 16
1A28 17
1A2C 18
1A30 19
1A3E 20
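
The same settings can also be collected in a `types.SimpleNamespace` from the standard library instead of a hand-written class. This is a sketch under the assumption that `preprocessing_db.main` only reads the attributes shown in the Args class above; the call itself is left commented out because it needs the MiSaTo package and data files.

```python
from types import SimpleNamespace

# Attribute names and values mirror the Args class defined above.
args = SimpleNamespace(
    datasetIn="../data/MD/h5_files/tiny_md.hdf5",
    strip_feature="atoms_element",
    strip_value=1,        # 1 for H when strip_feature is atoms_element
    begin=0,
    end=20,
    Adaptability=True,
    Pres_Lat=False,
    Pocket=0.0,           # 0.0 disables pocket stripping
    datasetOut="../data/MD/h5_files/tiny_md_out.hdf5",
)

# preprocessing_db.main(args)  # assumed to accept any object with these attributes
print(sorted(vars(args)))
```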


The same steps used for QM can be used to load the MD dataset. We start by loading the generated h5 file.

In [26]:
files_root =  ""

mdh5_file = '../data/MD/h5_files/tiny_md_out.hdf5'

train_idx = "../data/MD/splits/train_tinyMD.txt"
val_idx = "../data/MD/splits/val_tinyMD.txt"
test_idx = "../data/MD/splits/test_tinyMD.txt"

md_H5File = h5py.File(mdh5_file)

During preprocessing, the H-atoms were stripped (see the change in the atoms_ shapes) and a new feature, the adaptability, was calculated for each atom.

In [27]:
[(key, np.shape(md_H5File['10GS'][key])) for key in md_H5File['10GS'].keys()]

[('atoms_coordinates_ref', (3295, 3)),
 ('atoms_element', (3295,)),
 ('atoms_number', (3295,)),
 ('atoms_residue', (3295,)),
 ('atoms_type', (3295,)),
 ('feature_atoms_adaptability', (3295,)),
 ('frames_bSASA', (100,)),
 ('frames_distance', (100,)),
 ('frames_interaction_energy', (100,)),
 ('frames_rmsd_ligand', (100,)),
 ('molecules_begin_atom_index', (3,)),
 ('trajectory_coordinates', (100, 3295, 3))]
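
Stripping reduces the atom dimension of every per-atom array in the same way, which is why all atoms_ shapes and the second axis of trajectory_coordinates drop from 6593 to 3295. The toy sketch below illustrates the idea with plain lists; it is not the implementation used in preprocessing_db.py, and `strip_atoms` and the four-atom system are hypothetical.

```python
def strip_atoms(elements, frames, strip_value=1):
    """Drop every atom whose element equals strip_value from all arrays."""
    keep = [i for i, element in enumerate(elements) if element != strip_value]
    kept_elements = [elements[i] for i in keep]
    kept_frames = [[frame[i] for i in keep] for frame in frames]
    return kept_elements, kept_frames


# Hypothetical toy system (C, H, H, O) with two trajectory frames.
elements = [6, 1, 1, 8]
frames = [
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
    [[0.1, 0.0, 0.0], [1.1, 0.0, 0.0], [0.1, 1.0, 0.0], [0.1, 0.0, 1.0]],
]

kept_elements, kept_frames = strip_atoms(elements, frames)
print(kept_elements)                         # [6, 8]
print(len(kept_frames), len(kept_frames[0]))  # 2 2
```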

In [28]:
# Atom coordinates from the first frame
xyz = md_H5File['10GS']['trajectory_coordinates'][0, :, :] 

We can now instantiate the dataset and the dataloader.

In [29]:
train_dataset = ProtDataset(mdh5_file, idx_file=train_idx, transform=GNNTransformMD(), post_transform=T.RandomTranslate(0.05))

train_loader = DataLoader(train_dataset, batch_size=16, num_workers=48)

In [30]:
for idx, val in enumerate(train_loader):
    print(val)
    break

KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/micromamba/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/micromamba/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/user/micromamba/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/tillsiebenmorgen/Projects/MiSaTo-dataset/src/data/components/datasets.py", line 57, in __getitem__
    item["scores"] = pitem["atoms_soft_hard"][:][:cutoff]
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/home/user/micromamba/lib/python3.10/site-packages/h5py/_hl/group.py", line 357, in __getitem__
    oid = h5o.open(self.id, self._e(name), lapl=self._lapl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5o.pyx", line 190, in h5py.h5o.open
KeyError: "Unable to open object (object 'atoms_soft_hard' doesn't exist)"
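
The traceback shows `ProtDataset.__getitem__` reading a dataset named `atoms_soft_hard`, which is not among the keys listed above for the preprocessed file. A quick membership check before building the dataset can surface such a mismatch early. The sketch below uses a plain set of key names; the same `in` test also works on an `h5py` group. The helper `missing_keys` and the list of required names are illustrative, and only `atoms_soft_hard` is taken from the traceback.

```python
def missing_keys(group, required):
    """Return the required dataset names that are absent from a group."""
    return [name for name in required if name not in group]


# Keys listed above for '10GS' in the preprocessed file.
available = {
    "atoms_coordinates_ref", "atoms_element", "atoms_number",
    "atoms_residue", "atoms_type", "feature_atoms_adaptability",
    "frames_bSASA", "frames_distance", "frames_interaction_energy",
    "frames_rmsd_ligand", "molecules_begin_atom_index",
    "trajectory_coordinates",
}

print(missing_keys(available, ["atoms_type", "atoms_soft_hard"]))
# ['atoms_soft_hard']
```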


In [50]:
mddata = MDDataModule(files_root, h5file=mdh5_file, train=train_idx, val=val_idx, test=test_idx, num_workers=0)

In [51]:
mddata.setup()

In [52]:
train_loader = mddata.train_dataloader()

for idx, val in enumerate(train_loader):
    print(val)
    break

DataBatch(x=[38378, 11], edge_index=[2, 622332], edge_attr=[622332], y=[38378], pos=[38378, 3], ids=[16], batch=[38378], ptr=[17])
