In [1]:
import torch
import sys
sys.path.append(".") # Set pathway to dataset_prep and VarBatchSampler code here
import dataset_prep as dsp
import VarBatchSampler as vbs



Load the benchmarking dataset as an example

In [2]:
# Directory that holds the two benchmark dataset files
path = '.'  # Set pathway to benchmarking datasets

# The benchmark set is shipped in two parts; read each part from `path`.
# NOTE(review): torch.load unpickles its input, so only point this at trusted
# files. Newer PyTorch releases may default to weights_only=True and reject
# pickled dataset objects -- confirm against the installed version.
part_names = ("SPAF_dataset_benchmark_pt1.pt", "SPAF_dataset_benchmark_pt2.pt")
dataset_1, dataset_2 = [torch.load("{}/{}".format(path, name)) for name in part_names]

# Chain both parts so they can be indexed as one dataset
dataset = torch.utils.data.ConcatDataset([dataset_1, dataset_2])

Print a sample from the dataset to showcase how the data is organized

Note that all datasets (lim2000, res400, and res400+) are organized in this way

In [3]:
# Show every field of one sample (the last one in the dataset).
# A sample is indexed positionally: 0 = coordinates, 1 = energy, 2 = forces,
# 3 = element types, 4 = atom types, 5 = atom count, 6 = UniProt ID.
sample = dataset[-1]

print("Coordinates (in nm):")
print("---------------------------------------------")
print(sample[0])

print("\n\nGBn2 Solvation free energy")
# Raw string: the backslash in "\mathrm" must be printed literally and is not
# a valid escape sequence in a normal string literal.
print(r"E_\mathrm{GBn2} (in kJ/mol):")
print("---------------------------------------------")
print(sample[1])

print("\n\nGBn2 Forces")
print(r" -\nabla_{x} E_\mathrm{GBn2} (in kJ/mol/nm):")
print("---------------------------------------------")
print(sample[2])

print("\n\nElement types (listed as atomic numbers)")
print("---------------------------------------------")
print(sample[3])

print("\n\nCHARMM36 atom types (numeric embeddings)")
print("---------------------------------------------")
print(sample[4])

print("\n\nNumber of atoms in sample")
print("---------------------------------------------")
print(sample[5])

print("\n\nUniProt ID of sample")
print("---------------------------------------------")
print(sample[6])

Coordinates (in nm):
---------------------------------------------
tensor([[-1.9458,  1.2914, -0.4511],
        [-1.8936,  1.3812, -0.4466],
        [-1.9597,  1.2669, -0.5512],
        ...,
        [ 2.3217, -2.1743,  1.2122],
        [ 2.2697, -2.0944,  1.2234],
        [ 2.4905, -2.2721,  0.8994]])


GBn2 Solvation free energy
E_\mathrm{GBn2} (in kJ/mol):
---------------------------------------------
tensor([-76801.7188])


GBn2 Forces
 -\nabla_{x} E_\mathrm{GBn2} (in kJ/mol/nm):
---------------------------------------------
tensor([[  20.9618, -123.1268,   83.9408],
        [-144.7246,  -91.0444,  -60.6401],
        [  36.0034,  126.3221,   63.0963],
        ...,
        [ 879.9111, -335.8373, -204.7123],
        [-641.9651,  383.8910,  147.7843],
        [ 561.6282,  252.5732,   66.0976]])


Element types (listed as atomic numbers)
---------------------------------------------
tensor([7, 1, 1,  ..., 8, 1, 8])


CHARMM36 atom types (numeric embeddings)
-----------------------------

Example using the variable-sized batch sampler

As detailed in our paper, "variable-sized" here means that the number of samples in each batch varies, so that the total number of atoms across all samples in the batch stays within a user-set limit.

In [4]:
# Set atom limit for the variable-sized batch sampler
# (passed to VarBatchSampler below, so changing it here takes effect)
n_atom_lim = 20000

# Set distributed sampler
# (need an initial sampler to feed to our code, this serves that purpose)
base_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
                                                               num_replicas=1,
                                                               rank=0
                                                              )

# Print length of initial sampler (number of samples it yields)
print(len(base_sampler))

# Initialize collation func
coll_fn = dsp.collation_func(embed_type='names', # use C36 atom type embedding
                             unit='kJ/mol',
                             use_forces=False  # Set true to return forces and energies
                            )

# Build the variable batch sampler on top of the initial sampler
sampler = vbs.VarBatchSampler(base_sampler,
                              max_batch_size=20,
                              n_atom_limit=n_atom_lim,
                              dataset=dataset
                             )

# Print length of variable batch sampler (number of batches)
print(len(sampler))

# Create dataloader using sampler
# (batch_sampler yields whole batches of indices, so batch_size/shuffle are not set)
loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn = coll_fn,
                                     batch_sampler = sampler
                                    )

# Print the length of the dataloader
print(len(loader))

560
224
224


Loop through the dataloader

In [5]:
# Run through every batch once without doing any work on it.
# After the loop, `batch` is still bound to the last batch produced (the next
# cell prints it) and `i` to that batch's index.
for i, batch in enumerate(loader):
    pass

Print batch to show structure of data from dataloader

In [6]:
# Show every field of `batch`, the last batch left over from the loop above.
# A batch is indexed positionally: 0 = coordinates, 1 = energies,
# 2 = atom types, 3 = atoms per sample, 4 = per-atom sample index.
print("Coordinates (in nm):")
print("---------------------------------------------")
print(batch[0])

print("\n\nGBn2 Solvation free energy")
# Raw string: the backslash in "\mathrm" must be printed literally and is not
# a valid escape sequence in a normal string literal.
print(r"E_\mathrm{GBn2} (in kJ/mol):")
print("---------------------------------------------")
print(batch[1])

print("\n\nCHARMM36 atom types (numeric embeddings)")
print("---------------------------------------------")
print(batch[2])

print("\n\nNumber of atoms in sample")
print("---------------------------------------------")
print(batch[3])

print("\n\nBatch tensor indicating which sample")
print("each atom belongs to")
print("---------------------------------------------")
print(batch[4])


Coordinates (in nm):
---------------------------------------------
tensor([[-0.9248, -0.7053, -2.3402],
        [-0.8331, -0.7499, -2.3612],
        [-0.9931, -0.7796, -2.3151],
        ...,
        [-0.8798, -2.0542,  2.7118],
        [-0.5129, -2.2165,  2.7655],
        [-0.4584, -2.0077,  2.8084]])


GBn2 Solvation free energy
E_\mathrm{GBn2} (in kJ/mol):
---------------------------------------------
tensor([-20616.3789, -38702.4258, -23591.6133])


CHARMM36 atom types (numeric embeddings)
---------------------------------------------
tensor([63, 17, 19,  ..., 27, 72, 80])


Number of atoms in sample
---------------------------------------------
tensor([4808, 6196, 8037])


Batch tensor indicating which sample
each atom belongs to
---------------------------------------------
tensor([0, 0, 0,  ..., 2, 2, 2])
