# Datasets and DataLoaders

We provide Dataset and DataLoader classes for working with Graphein's ``Protein`` data structure. These are analogous to those found in PyTorch and PyTorch Geometric.


## 1. List Datasets

List datasets are the most straightforward option for datasets that are small enough to fit in memory. They are also very flexible — all preprocessing is left to the user — and simply wrap a list of Proteins (or Data objects) into a dataset.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/protein_tensors.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/a-r-j/graphein/blob/master/notebooks/protein_tensors.ipynb)

In [2]:
import graphein
# Disable graphein's verbose mode (presumably suppresses its log output in the cells below).
graphein.verbose(False)

In [3]:
# Start by building a list of Protein objects.
import torch
from graphein.protein.tensor.data import to_protein_mp

# Two PDB entries repeated ten times -> 20 Protein objects, converted with
# multiprocessing across 4 cores.
pdb_codes = ["4hhb", "3eiy"] * 10
data_list = to_protein_mp(pdb_codes=pdb_codes, num_cores=4)

# Optional: do some processing on the data
def processing_pipeline(protein):
    """Attach a random label to ``protein`` and return it.

    The label is a 1-element tensor from ``torch.rand`` stored as
    ``protein.y``. The object is modified in place and also returned, so the
    function can be used directly in a comprehension or ``map``.
    """
    label = torch.rand(1)
    protein.y = label
    return protein

# Apply the (optional) processing step to every Protein in the list.
data_list = list(map(processing_pipeline, data_list))
print(data_list)

100%|██████████| 20/20 [00:04<00:00,  4.32it/s]

[Protein(fill_value=1e-05, atom_list=[37], chains=[574], residue_type=[574], residues=[574], x=[574, 37, 3], residue_id=[574], id='4hhb', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[174], residue_type=[174], residues=[174], x=[174, 37, 3], residue_id=[174], id='3eiy', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[574], residue_type=[574], residues=[574], x=[574, 37, 3], residue_id=[574], id='4hhb', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[174], residue_type=[174], residues=[174], x=[174, 37, 3], residue_id=[174], id='3eiy', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[574], residue_type=[574], residues=[574], x=[574, 37, 3], residue_id=[574], id='4hhb', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[174], residue_type=[174], residues=[174], x=[174, 37, 3], residue_id=[174], id='3eiy', y=[1]), Protein(fill_value=1e-05, atom_list=[37], chains=[574], residue_type=[574], residues=[574], x=[574, 37, 3], residue_id=[574], id=




We can now wrap this list into a ``ProteinGraphListDataset``:

In [4]:
import os

from graphein.protein.tensor.dataset import ProteinGraphListDataset

# Wrap the in-memory list of Proteins in a dataset. The processed output is
# written under ``root`` (here ``./processed/``) using the given ``name``.
list_dataset = ProteinGraphListDataset(
    root=".",
    name="list_dataset_tests",
    data_list=data_list,
)
print("Dataset: ", list_dataset)
print("Example: ", list_dataset[4])

# Inspect the files written during processing.
os.listdir("./processed/")

Dataset:  ProteinGraphListDataset(20)
Example:  Protein(fill_value=[1], atom_list=[37], chains=[574], residue_type=[574], x=[574, 37, 3], id='4hhb', residue_id=[574], residues=[574], y=[1])


['pre_filter.pt',
 'data_test.pt',
 'data_list_dataset_tests.pt',
 'pre_transform.pt',
 'list_dataset_tests.pt']

In [5]:
# This file was written by the list-dataset cell above, so it is a trusted
# source. It contains pickled ``Protein`` objects, which need full unpickling:
# PyTorch >= 2.6 defaults to ``weights_only=True``, and that mode refuses
# arbitrary Python classes.
torch.load("./processed/list_dataset_tests.pt", weights_only=False)

(Protein(fill_value=[20], atom_list=[20], chains=[7480], residue_type=[7480], x=[7480, 37, 3], id=[20], residue_id=[20], residues=[20], y=[20]),
 defaultdict(dict,
             {'fill_value': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                      18, 19, 20]),
              'atom_list': tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                      18, 19, 20]),
              'chains': tensor([   0,  574,  748, 1322, 1496, 2070, 2244, 2818, 2992, 3566, 3740, 4314,
                      4488, 5062, 5236, 5810, 5984, 6558, 6732, 7306, 7480]),
              'residue_type': tensor([   0,  574,  748, 1322, 1496, 2070, 2244, 2818, 2992, 3566, 3740, 4314,
                      4488, 5062, 5236, 5810, 5984, 6558, 6732, 7306, 7480]),
              'x': tensor([   0,  574,  748, 1322, 1496, 2070, 2244, 2818, 2992, 3566, 3740, 4314,
                      4488, 5062, 5236, 5810, 5984, 6558, 6732, 7306, 7480]),
  

## 2. Datasets

### From PDB Codes

In [6]:
import os
from graphein.protein.tensor.dataset import ProteinDataset

# From PDB Codes: structure files are stored in ``pdb_dir`` and processed
# tensors in ``out_dir``.
codes = ["4hhb", "3eiy", "5caj"]
ds = ProteinDataset(
    pdb_codes=codes,
    root=".",
    pdb_dir="./pdbs/",
    out_dir="./processed_files/",
    overwrite=True,
    # Random placeholder labels: a length-4 graph label and a length-10 node
    # label per structure. The node-label length is arbitrary here and does
    # not match the number of residues.
    graph_labels=[torch.rand(4) for _ in codes],
    node_labels=[torch.rand(10) for _ in codes],
    chain_selections=["all"] * len(codes),
)

100%|██████████| 3/3 [00:00<00:00, 56936.25it/s]
Processing...
100%|██████████| 3/3 [00:00<00:00, 25.15it/s]
1it [00:00,  3.71it/s]
Done!


In [7]:
# List the structure files and processed tensor files, then show the first
# three dataset entries.
pdb_files = os.listdir("pdbs")
tensor_files = os.listdir("./processed_files/")
print("PDBs: ", pdb_files)
print("tensors: ", tensor_files)
print(*[ds[idx] for idx in range(3)])

PDBs:  ['0.pdb', '4hhb.pdb', '1.pdb', '2.pdb', '3eiy.pdb', '5caj.pdb']
tensors:  ['pre_filter.pt', '0_all.pt', '1_all.pt', 'pre_transform.pt', '5caj_all.pt', '2_all.pt', '3eiy_all.pt', '4hhb_all.pt']
Protein(fill_value=1e-05, atom_list=[37], chains=[574], residue_type=[574], residues=[574], x=[574, 37, 3], residue_id=[574], id='./pdbs//4hhb.pdb', graph_y=[4], node_y=[10]) Protein(fill_value=1e-05, atom_list=[37], chains=[174], residue_type=[174], residues=[174], x=[174, 37, 3], residue_id=[174], id='./pdbs//3eiy.pdb', graph_y=[4], node_y=[10]) Protein(fill_value=1e-05, atom_list=[37], chains=[510], residue_type=[510], residues=[510], x=[510, 37, 3], residue_id=[510], id='./pdbs//5caj.pdb', graph_y=[4], node_y=[10])


### From Sequences (via ESMFold)

Structural datasets can also be generated from amino acid sequences by predicting structures with ESMFold.

In [8]:
# Fold each sequence with ESMFold, then build the dataset from the predicted
# structures. ``sequences`` also accepts a path to a FASTA file.
seqs = ["AGYFGMTAME"] * 3
ds = ProteinDataset(
    sequences=seqs,
    root=".",
    pdb_dir="./pdbs/",
    out_dir="./processed_files/",
    overwrite=True,
    graph_labels=[torch.rand(1) for _ in seqs],
    # NOTE: these node-label lengths (16, 14, 12) are arbitrary placeholders
    # and do not match the 10-residue sequences above.
    node_labels=[torch.rand(n) for n in (16, 14, 12)],
    chain_selections=["all"] * 3,
)

100%|██████████| 3/3 [00:29<00:00,  9.75s/it]
100%|██████████| 3/3 [00:00<00:00, 21546.08it/s]
Processing...
100%|██████████| 3/3 [00:00<00:00, 104.12it/s]
1it [00:01,  1.07s/it]
Done!


In [9]:
# List the structure files and processed tensor files, then show the first
# three dataset entries.
pdb_files = os.listdir("pdbs")
tensor_files = os.listdir("./processed_files/")
print("PDBs: ", pdb_files)
print("tensors: ", tensor_files)
print(*[ds[idx] for idx in range(3)])

PDBs:  ['0.pdb', '4hhb.pdb', '1.pdb', '2.pdb', '3eiy.pdb', '5caj.pdb']
tensors:  ['pre_filter.pt', '0_all.pt', '1_all.pt', 'pre_transform.pt', '5caj_all.pt', '2_all.pt', '3eiy_all.pt', '4hhb_all.pt']
Protein(fill_value=1e-05, atom_list=[37], chains=[10], residue_type=[10], residues=[10], x=[10, 37, 3], residue_id=[10], id='./pdbs//0.pdb', graph_y=[1], node_y=[16]) Protein(fill_value=1e-05, atom_list=[37], chains=[10], residue_type=[10], residues=[10], x=[10, 37, 3], residue_id=[10], id='./pdbs//1.pdb', graph_y=[1], node_y=[14]) Protein(fill_value=1e-05, atom_list=[37], chains=[10], residue_type=[10], residues=[10], x=[10, 37, 3], residue_id=[10], id='./pdbs//2.pdb', graph_y=[1], node_y=[12])


### From PDB Files

Datasets can be created from collections of local structure files.

In [1]:
import os
import torch
from graphein.protein.tensor.dataset import ProteinDataset

# From local PDB files: the file names in ``pdb_paths`` appear to be resolved
# relative to ``pdb_dir`` (e.g. "4hhb.pdb" -> "./pdbs/4hhb.pdb"; TODO confirm).
ds = ProteinDataset(
    pdb_paths=["4hhb.pdb", "3eiy.pdb", "5caj.pdb"],
    root=".",
    pdb_dir="./pdbs",
    out_dir="./processed_files/",
    overwrite=True,
    # Random placeholder labels (one graph label and one node label per file).
    graph_labels=[torch.rand(4), torch.rand(4), torch.rand(4)],
    node_labels=[torch.rand(10), torch.rand(10), torch.rand(10)],
    chain_selections=["all", "all", "all"],
)

100%|██████████| 3/3 [00:00<00:00, 14446.51it/s]
Processing...


0it [00:00, ?it/s]

['./pdbs/4hhb.pdb', './pdbs/3eiy.pdb', './pdbs/5caj.pdb']


100%|██████████| 3/3 [00:00<00:00, 11.73it/s]
1it [00:00,  2.45it/s]
Done!


In [2]:
# List the structure files and processed tensor files, then show the first
# three dataset entries.
pdb_files = os.listdir("pdbs")
tensor_files = os.listdir("./processed_files/")
print("PDBs: ", pdb_files)
print("tensors: ", tensor_files)
print(*[ds[idx] for idx in range(3)])

PDBs:  ['0.pdb', '4hhb.pdb', '1.pdb', '2.pdb', '3eiy.pdb', '5caj.pdb']
tensors:  ['pre_filter.pt', '0_all.pt', '1_all.pt', 'pre_transform.pt', '5caj_all.pt', '2_all.pt', '3eiy_all.pt', '4hhb_all.pt']
Protein(fill_value=1e-05, atom_list=[37], x=[574, 37, 3], residue_type=[574], residues=[574], chains=[574], residue_id=[574], id='./pdbs/4hhb.pdb', graph_y=[4], node_y=[10]) Protein(fill_value=1e-05, atom_list=[37], x=[174, 37, 3], residue_type=[174], residues=[174], chains=[174], residue_id=[174], id='./pdbs/3eiy.pdb', graph_y=[4], node_y=[10]) Protein(fill_value=1e-05, atom_list=[37], x=[510, 37, 3], residue_type=[510], residues=[510], chains=[510], residue_id=[510], id='./pdbs/5caj.pdb', graph_y=[4], node_y=[10])


## 3. DataLoaders

In [6]:
from graphein.protein.tensor.dataloader import ProteinDataLoader

dl = ProteinDataLoader(ds, batch_size=4, shuffle=True, num_workers=0, pin_memory=False)

print("Dataset size: ", len(ds))
print("Num batches: ", len(dl))

# Take only the first (shuffled) batch for inspection. Later cells use
# ``batch``, so fail loudly here if the loader yields nothing instead of
# leaving ``batch`` undefined.
batch = next(iter(dl), None)
if batch is None:
    raise RuntimeError("DataLoader produced no batches; is the dataset empty?")
print("Example batch: ", batch)
print("Num graphs: ", batch.num_graphs)
print("Num nodes: ", batch.num_nodes)

Dataset size:  3
Num batches:  1
Example batch:  ProteinProteinBatch(fill_value=[3], atom_list=[3], x=[1258, 37, 3], residue_type=[1258], residues=[3], chains=[1258], residue_id=[3], id=[3], graph_y=[12], node_y=[30], batch=[1258], ptr=[4])
Num graphs:  3
Num nodes:  1258


In [4]:
# Split the batch back into its individual Protein objects.
proteins = batch.to_protein_list()
proteins

[Protein(fill_value=[1], atom_list=[37], x=[174, 37, 3], residue_type=[174], residues=[174], chains=[174], residue_id=[174], id='./pdbs/3eiy.pdb', graph_y=[4], node_y=[10]),
 Protein(fill_value=[1], atom_list=[37], x=[510, 37, 3], residue_type=[510], residues=[510], chains=[510], residue_id=[510], id='./pdbs/5caj.pdb', graph_y=[4], node_y=[10]),
 Protein(fill_value=[1], atom_list=[37], x=[574, 37, 3], residue_type=[574], residues=[574], chains=[574], residue_id=[574], id='./pdbs/4hhb.pdb', graph_y=[4], node_y=[10])]

In [5]:
# Plot the structures contained in the batch (presumably an interactive 3D view).
structure_plot = batch.plot_structure()
structure_plot

In [7]:
# Dihedral-angle features for the batched structures. The output below has 6
# columns with values in [-1, 1]; presumably one row per residue, with
# sin/cos-encoded angles (TODO confirm layout).
dihedral_features = batch.dihedrals()
dihedral_features

tensor([[ 1.0000,  0.0000, -0.8665, -0.4991,  0.6185,  0.7858],
        [-0.9345, -0.3560,  1.0000,  0.0082, -0.5489, -0.8359],
        [-0.9938, -0.1114,  0.9492,  0.3146,  0.2677,  0.9635],
        ...,
        [-0.8844, -0.4668, -0.4669,  0.8843, -0.5579, -0.8299],
        [-0.9998,  0.0215, -0.0722, -0.9974,  0.9920, -0.1258],
        [-0.6148,  0.7887,  1.0000,  0.0000,  1.0000,  0.0000]])