# Learning at Proteome Scale!

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/foldcomp.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/a-r-j/graphein/blob/master/notebooks/foldcomp.ipynb)

The wonderful folks in the [Steinegger group](https://steineggerlab.com/en/) have developed [FoldComp](https://github.com/steineggerlab/foldcomp), an excellent tool for managing and compressing databases of predicted structures.

![FoldComp](https://raw.githubusercontent.com/steineggerlab/foldcomp/master/.github/img/format_benchmark_light.png)


We can use FoldComp as a backend for managing massive protein structure datasets, saving an order of magnitude in disk space.


The `FoldCompDataset` class in Graphein enables parsing to both PyG `Data` and Graphein `Protein`/`ProteinBatch` objects ([see this tutorial for more](https://github.com/a-r-j/graphein/blob/master/notebooks/protein_tensors.ipynb)).





In [27]:
from graphein.ml.datasets.foldcomp_dataset import FoldCompDataset
from rich import inspect

# Pretty-print an overview of the FoldCompDataset class with rich's `inspect`
# helper before constructing a dataset in the next cell.
inspect(FoldCompDataset)

In [26]:
# Build a dataset backed by the FoldComp "afdb_swissprot_v4" database.
ds = FoldCompDataset(
    # Set directory path to download data to
    root="./test_data/",
    # Name of the database. See: https://github.com/steineggerlab/foldcomp
    database="afdb_swissprot_v4",
    # List of IDs to include in the dataset. If None, all IDs are included.
    ids=None,
    # Fraction of database to use
    fraction=1,
    # Whether to use Graphein's Protein objects (True) or standard PyG Data (False)
    use_graphein=True,
)

print("Number of proteins in dataset: ", len(ds))

100%|██████████| 542378/542378 [00:00<00:00, 3310570.45it/s]


Processing...


Number of proteins in dataset:  542378


Done!


In [27]:
# Peek at the first 11 proteins and stop early instead of iterating over the
# whole dataset.
for count, protein in enumerate(ds, start=1):
    print(protein)
    if count == 11:
        break

Protein(fill_value=1e-05, atom_list=[37], id='AF-P42477-F1-model_v4.pdb', x=[400, 37, 3], residues=[400], chains=[400], residue_id=[400], residue_type=[400])
Protein(fill_value=1e-05, atom_list=[37], id='AF-Q758U2-F1-model_v4.pdb', x=[383, 37, 3], residues=[383], chains=[383], residue_id=[383], residue_type=[383])
Protein(fill_value=1e-05, atom_list=[37], id='AF-B2TY17-F1-model_v4.pdb', x=[97, 37, 3], residues=[97], chains=[97], residue_id=[97], residue_type=[97])
Protein(fill_value=1e-05, atom_list=[37], id='AF-B7XIW9-F1-model_v4.pdb', x=[224, 37, 3], residues=[224], chains=[224], residue_id=[224], residue_type=[224])
Protein(fill_value=1e-05, atom_list=[37], id='AF-Q2NYM2-F1-model_v4.pdb', x=[297, 37, 3], residues=[297], chains=[297], residue_id=[297], residue_type=[297])
Protein(fill_value=1e-05, atom_list=[37], id='AF-P36990-F1-model_v4.pdb', x=[49, 37, 3], residues=[49], chains=[49], residue_id=[49], residue_type=[49])
Protein(fill_value=1e-05, atom_list=[37], id='AF-P15431-F1-mod

## Selection
You can access proteins in two ways:

1. Using their index
2. By their ID

In [28]:
# Index
# Fetch the entry once and reuse it for printing and plotting; every `ds[10]`
# expression is a separate dataset lookup.
protein_by_index = ds[10]
print(protein_by_index)
protein_by_index.plot_structure()

Protein(fill_value=1e-05, atom_list=[37], id='AF-B7LMI2-F1-model_v4.pdb', x=[356, 37, 3], residues=[356], chains=[356], residue_id=[356], residue_type=[356])


In [31]:
# ID
# Look the protein up once by its identifier and reuse the result for both
# printing and plotting. Note that the identifier is passed without the
# ".pdb" suffix that appears in the printed `id` field.
protein_by_id = ds.get("AF-B7LMI2-F1-model_v4")
print(protein_by_id)
protein_by_id.plot_structure()

Protein(fill_value=1e-05, atom_list=[37], id='AF-B7LMI2-F1-model_v4.pdb', x=[356, 37, 3], residues=[356], chains=[356], residue_id=[356], residue_type=[356])


## Creating a DataLoader

In [19]:
from torch_geometric.loader import DataLoader

# shuffle=True: proteins are drawn into batches in random order.
dl = DataLoader(ds, batch_size=32, shuffle=True)

# Print the first six batches only. `batch` stays bound to the last one
# printed once the loop exits.
for batch_number, batch in enumerate(dl, start=1):
    print(batch)
    if batch_number == 6:
        break

ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[13272, 37, 3], residues=[32], chains=[13272], residue_id=[32], residue_type=[13272], batch=[13272], ptr=[33])
ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[11045, 37, 3], residues=[32], chains=[11045], residue_id=[32], residue_type=[11045], batch=[11045], ptr=[33])
ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[10277, 37, 3], residues=[32], chains=[10277], residue_id=[32], residue_type=[10277], batch=[10277], ptr=[33])
ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[11266, 37, 3], residues=[32], chains=[11266], residue_id=[32], residue_type=[11266], batch=[11266], ptr=[33])
ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[10598, 37, 3], residues=[32], chains=[10598], residue_id=[32], residue_type=[10598], batch=[10598], ptr=[33])
ProteinBatch(fill_value=[32], atom_list=[32], id=[32], x=[10498, 37, 3], residues=[32], chains=[10498], residue_id=[32], residue_type=[10498], batch=[10498], pt

In [20]:
# `batch` is the last batch printed by the loop in the previous cell.
batch.plot_structure()

## Pure PyG

If we set `use_graphein` to `False`, we get standard PyG `Data` objects.

In [32]:
# Same database and settings as the earlier dataset, except use_graphein=False.
# This rebinds `ds`, so later cells see the PyG-flavoured dataset.
ds = FoldCompDataset(
    root="./test_data/",
    database="afdb_swissprot_v4",
    ids=None,
    fraction=1,
    use_graphein=False,
)

100%|██████████| 542378/542378 [00:00<00:00, 3182352.35it/s]


Processing...


Done!


In [33]:
# Fetch the entry once, then show both its contents and its Python type
# (instead of indexing the dataset twice).
sample = ds[1]
print(sample)
print(type(sample))

Data(residue_type=[1218], x=[1218, 37, 3], residues=[1218], residue_id=[1218], id='AF-Q54WH2-F1-model_v4.pdb', fill_value=1e-05, atom_list=[37], chains=[1218])
<class 'torch_geometric.data.data.Data'>


## Plug and play with PyTorch Lightning!

We provide a PyTorch Lightning [`LightningDataModule`](https://lightning.ai/docs/pytorch/latest/data/datamodule.html) wrapper for FoldComp datasets.

In [21]:
from graphein.ml.datasets.foldcomp_dataset import FoldCompLightningDataModule
from rich import inspect

# Pretty-print an overview of the FoldCompLightningDataModule class with
# rich's `inspect` helper.
inspect(FoldCompLightningDataModule)

In [22]:
# Data module over the same database; no train/val/test split arguments are
# passed here.
data_module = FoldCompLightningDataModule(
    data_dir="./test_data/",
    database="afdb_swissprot_v4",
    batch_size=32,
    num_workers=4,
)

# setup() is called before the per-split datasets are accessed in the next cell.
data_module.setup()

100%|██████████| 542378/542378 [00:00<00:00, 3294375.73it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3300761.77it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3199872.58it/s]


Processing...


Done!


In [23]:
# Accessing dataset
# One dataset per split, printed one per line.
print(data_module.train_ds, data_module.val_ds, data_module.test_ds, sep="\n")

# One dataloader per split; each loader is created and printed in turn.
for make_loader in (
    data_module.train_dataloader,
    data_module.val_dataloader,
    data_module.test_dataloader,
):
    print(make_loader())

FoldCompDataset(542378)
FoldCompDataset(542378)
FoldCompDataset(542378)
<torch_geometric.loader.dataloader.DataLoader object at 0x7f50bf730a90>
<torch_geometric.loader.dataloader.DataLoader object at 0x7f50bf730a30>
<torch_geometric.loader.dataloader.DataLoader object at 0x7f50bf730a90>


### Splitting the data

You can create splits either by providing floats that specify the size of each partition or by providing a list of identifiers for each partition.

In [24]:
# Splitting the data with partition sizes

# The three fractions below add up to 1.0 (70% / 20% / 10%).
data_module = FoldCompLightningDataModule(
    data_dir="./test_data/",
    database="afdb_swissprot_v4",
    batch_size=32,
    num_workers=4,
    train_split=0.7,
    val_split=0.2,
    test_split=0.1,
)

data_module.setup()

100%|██████████| 542378/542378 [00:00<00:00, 3289516.77it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3292825.97it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3297972.00it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3280352.50it/s]


Processing...


Done!


In [25]:
# Splitting the data with explicit lists of identifiers.
# NOTE: "AF-P64488-F1-model_v4" is listed in both the train and test IDs below.
# That is harmless for this demo, but splits used for a real evaluation should
# be disjoint.
train_ids = ["AF-P64030-F1-model_v4", "AF-P64488-F1-model_v4"]
val_ids = ["AF-B2ICL0-F1-model_v4"]
test_ids = ["AF-P64488-F1-model_v4"]

data_module = FoldCompLightningDataModule(
    data_dir="./test_data/",
    database="afdb_swissprot_v4",
    batch_size=32,
    num_workers=4,
    train_split=train_ids,
    val_split=val_ids,
    test_split=test_ids,
)

data_module.setup()


100%|██████████| 542378/542378 [00:00<00:00, 3306408.35it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3289488.23it/s]


Processing...


Done!


100%|██████████| 542378/542378 [00:00<00:00, 3294728.80it/s]


Processing...


Done!


In [26]:
# Show each split's dataset followed by its entries, one item per line.
# In the recorded output the train entries come back in a different order
# from `train_ids`, so do not rely on the list order.
print(data_module.train_ds, data_module.train_ds[0], data_module.train_ds[1], sep="\n")
print(data_module.val_ds, data_module.val_ds[0], sep="\n")
print(data_module.test_ds, data_module.test_ds[0], sep="\n")

FoldCompDataset(2)
Protein(fill_value=1e-05, atom_list=[37], id='AF-P64488-F1-model_v4.pdb', x=[119, 37, 3], residues=[119], chains=[119], residue_id=[119], residue_type=[119])
Protein(fill_value=1e-05, atom_list=[37], id='AF-P64030-F1-model_v4.pdb', x=[398, 37, 3], residues=[398], chains=[398], residue_id=[398], residue_type=[398])
FoldCompDataset()
Protein(fill_value=1e-05, atom_list=[37], id='AF-B2ICL0-F1-model_v4.pdb', x=[276, 37, 3], residues=[276], chains=[276], residue_id=[276], residue_type=[276])
FoldCompDataset()
Protein(fill_value=1e-05, atom_list=[37], id='AF-P64488-F1-model_v4.pdb', x=[119, 37, 3], residues=[119], chains=[119], residue_id=[119], residue_type=[119])
