# Getting started with datasets

[<img src="https://colab.research.google.com/assets/colab-badge.svg">](https://colab.research.google.com/github/BorgwardtLab/proteinshake/blob/main/docs/readthedocs/source/notebooks/dataset.ipynb)

ProteinShake implements all the steps from raw protein coordinate files (PDB/mmCIF) to a training-ready dataset.
We also host the results of these computations for several datasets, which are listed [here](modules/datasets.rst).

You can obtain a dataset ready for model training in one line. 

Here is an example of how to get a `torch_geometric.data.Dataset` object of proteins pulled from the RCSB Protein Data Bank as epsilon graphs.

In [1]:
# If you are running on colab, uncomment the line below and run the cell to install ProteinShake

#!pip install proteinshake


In [6]:
from proteinshake.datasets import RCSBDataset

# a graph dataset of epsilon-neighborhood graphs with a radius of 8 Angstrom, in PyTorch Geometric
dataset = RCSBDataset(root='./data', verbosity=1).to_graph(eps=8).pyg()


The above code executes the three main steps of dataset preparation:

1. Loading the processed protein data: `RCSBDataset(root='data')`
2. Converting the proteins to your representation of choice: `.to_graph(eps=8)`
3. Converting the dataset to your framework of choice: `.pyg()`

To reproduce the processing yourself, pass the `use_precomputed=False` flag to `RCSBDataset()`.
This runs all the processing steps locally. By default, we instead try to fetch the precomputed dataset from the ones we host on Zenodo, because the processing can be quite time-consuming.

Next, we look at each of the three steps in more detail.

## Loading protein data

The first step in the snippet above does most of the legwork.
Once the dataset object is created, it holds an iterable of dictionaries, one per protein, which you access through the `Dataset.proteins()` method.

In [2]:
dataset = RCSBDataset(root='./data', verbosity=1)
proteins = dataset.proteins(resolution='residue')
print(next(proteins)['protein'])

{'ID': '2NXC', 'sequence': 'MWVYRLKGTLEALDPILPGLFDGGARGLWEREGEVWAFFPAPVDLPYEGVWEEVGDEDWLEAWRRDLKPALAPPFVVLAPWHTWEGAEIPLVIEPGGHHETTRLALKALARHLRPGDKVLDLGTGSGVLAIAAEKLGGKALGVDIDPMVLPQAEANAKRNGVRPRFLEGSLEAALPFGPFDLLVANLYAELHAALAPRYREALVPGGRALLTGILKDRAPLVREAMAGAGFRPLEEAAEGEWVLLAYGR'}
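
Because `proteins` is a plain iterable of dictionaries, ordinary Python tools are enough to summarize it. The sketch below does not call ProteinShake: `toy_proteins` is a hypothetical stand-in that yields dictionaries shaped like the output above, with made-up IDs and sequences, so the pattern can be run without downloading anything.

```python
from collections import Counter

# Hypothetical stand-in for dataset.proteins(): yields dictionaries shaped
# like the output above. The IDs and sequences are toy values, not real entries.
def toy_proteins():
    yield {'protein': {'ID': 'TOY1', 'sequence': 'MWVYRLKG'}}
    yield {'protein': {'ID': 'TOY2', 'sequence': 'GAGAKL'}}

# Sequence length per protein ID
lengths = {p['protein']['ID']: len(p['protein']['sequence']) for p in toy_proteins()}

# Amino acid composition across the whole toy dataset
composition = Counter(aa for p in toy_proteins() for aa in p['protein']['sequence'])

print(lengths)
print(composition.most_common(3))
```

With a real dataset you would presumably replace `toy_proteins()` with `dataset.proteins(resolution='residue')`; note that a generator is consumed once, so call the method again for each pass.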


Different implementations of the `Dataset` parent class let you customize this step.
For example, the `RCSBDataset` accepts a `from_list` argument which lets you specify which PDBs to fetch.

## Protein representations

Once the processed protein data is in the `Dataset` object we need to convert it to a representation that works with downstream deep learning models.
We currently support graphs, point clouds, and voxels.

In [3]:
point_dataset = dataset.to_point()
graph_dataset = dataset.to_graph(eps=8)
voxel_dataset = dataset.to_voxel()

These methods can be applied to any `Dataset` subclass and perform the necessary computations for the different representations.
Some notes on each representation:

### Point Clouds

Point clouds simply return the x, y, z coordinates of the alpha carbon for each residue.
You can find some relevant processing steps, such as centering and rotating, in the [transforms module](modules/transforms.rst).
For example, to center the point clouds:

In [4]:
from proteinshake.transforms import CenterTransform

points = RCSBDataset(root='./data_transformed', verbosity=1).to_point(transform=CenterTransform())
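
To make the idea of centering concrete, here is a minimal NumPy sketch that subtracts the centroid from a point cloud. It assumes that centering means moving the mean coordinate to the origin; whether `CenterTransform` does exactly this internally is an assumption, and the coordinates are toy values.

```python
import numpy as np

# Toy alpha-carbon coordinates (N residues x 3), made up for illustration
coords = np.array([[1.0, 2.0, 3.0],
                   [3.0, 2.0, 1.0],
                   [2.0, 5.0, 2.0]])

# Subtract the centroid so the point cloud is centered at the origin
centered = coords - coords.mean(axis=0, keepdims=True)

print(centered.mean(axis=0))
```

Centering leaves all pairwise distances unchanged, so it only removes the arbitrary placement of the protein in the coordinate frame.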

### Graphs

Graph construction can be done in two ways: epsilon neighborhoods or k-nearest neighbors.
The epsilon graph is chosen when `eps` is passed with a distance threshold.
All pairs of residues within the distance threshold are connected by an edge.
If `k=4` is passed then each residue is connected by an edge to its 4 nearest neighbors.

In [5]:
knn_graph = dataset.to_graph(k=4)
eps_graph = dataset.to_graph(eps=8)
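
The two construction rules can be sketched directly from a coordinate array. The helpers below (`pairwise_distances`, `eps_edges`, `knn_edges`) are hypothetical names written for this illustration and are not part of ProteinShake. The toy coordinates are made up, and the k-nearest-neighbor edges here are directed, which may differ from what the library produces.

```python
import numpy as np

def pairwise_distances(coords):
    """Euclidean distance between every pair of points, shape (N, N)."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def eps_edges(coords, eps):
    """Directed edges (i, j) for all pairs within distance eps, no self-loops."""
    adjacency = pairwise_distances(coords) <= eps
    np.fill_diagonal(adjacency, False)
    return np.argwhere(adjacency)

def knn_edges(coords, k):
    """Directed edges (i, j) from each point i to its k nearest neighbors j."""
    dist = pairwise_distances(coords)
    np.fill_diagonal(dist, np.inf)  # a point is not its own neighbor
    neighbors = np.argsort(dist, axis=1)[:, :k]
    sources = np.repeat(np.arange(len(coords)), k)
    return np.stack([sources, neighbors.ravel()], axis=1)

# Four toy residues on a line; the last one is far from the others
coords = np.array([[0.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0],
                   [7.0, 0.0, 0.0],
                   [20.0, 0.0, 0.0]])

print(eps_edges(coords, eps=8))
print(knn_edges(coords, k=1))
```

The example highlights the practical difference: with `eps=8` the distant residue has no edges at all, whereas with k-nearest neighbors every residue gets exactly `k` outgoing edges regardless of how far away its neighbors are.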

You can obtain a weighted graph where weights correspond to the distance between connected residues:

In [6]:
eps_graph = dataset.to_graph(eps=8, weighted_edges=True)
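
As a rough illustration of distance-weighted edges, the sketch below builds an epsilon graph from toy coordinates and keeps the distance of each retained pair as its weight. The variable names `edge_index` and `edge_weight` are chosen for this example; how ProteinShake stores the weights is not shown here.

```python
import numpy as np

# Toy residue coordinates, made up for illustration
coords = np.array([[0.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0],
                   [7.0, 0.0, 0.0],
                   [20.0, 0.0, 0.0]])
eps = 8.0

# All pairwise Euclidean distances, shape (N, N)
dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

# Keep pairs within eps, excluding self-loops
mask = (dist <= eps) & ~np.eye(len(coords), dtype=bool)

edge_index = np.argwhere(mask)  # (E, 2) array of (source, target) pairs
edge_weight = dist[mask]        # (E,) distances, in the same row-major order

for (i, j), w in zip(edge_index, edge_weight):
    print(i, j, w)
```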

### Voxels

For the voxel representation we place a 3D grid of voxels over the protein and indicate the residue/atom occupancy of each voxel with a one-hot encoding of the amino acid or atom types present in that voxel. You can choose how this encoding is aggregated if multiple residues fall into the same voxel (`mean` or `sum`):

In [7]:
voxel = dataset.to_voxel(voxelsize=10, aggregation='mean')

By default, all voxel grids are padded to the size of the largest protein. By providing `gridsize` you can fix the grid to a constant size instead, in which case larger proteins are truncated.
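
The occupancy encoding and the two aggregation modes can be sketched in a few lines of NumPy. `voxelize` is a hypothetical helper written for this illustration, not the library's implementation. It assumes the grid starts at the minimum coordinate of the protein and uses integer labels in place of amino acid types; padding and `gridsize` truncation are left out.

```python
import numpy as np

def voxelize(coords, labels, n_types, voxelsize, aggregation='mean'):
    """Place one-hot type encodings on a 3D grid and aggregate per voxel."""
    # Shift the protein so its minimum corner sits at the grid origin,
    # then map each coordinate to an integer voxel index
    idx = np.floor((coords - coords.min(axis=0)) / voxelsize).astype(int)
    gridshape = tuple(int(n) for n in idx.max(axis=0) + 1)

    grid = np.zeros(gridshape + (n_types,))
    counts = np.zeros(gridshape)
    for (i, j, k), label in zip(idx, labels):
        grid[i, j, k, label] += 1.0  # add the one-hot vector of this residue
        counts[i, j, k] += 1

    if aggregation == 'mean':
        occupied = counts > 0
        grid[occupied] /= counts[occupied][:, None]
    elif aggregation != 'sum':
        raise ValueError("aggregation must be 'mean' or 'sum'")
    return grid

# Three toy residues of two types; the first two share a voxel
coords = np.array([[1.0, 1.0, 1.0],
                   [2.0, 2.0, 2.0],
                   [15.0, 1.0, 1.0]])
labels = [0, 1, 0]

grid_sum = voxelize(coords, labels, n_types=2, voxelsize=10, aggregation='sum')
grid_mean = voxelize(coords, labels, n_types=2, voxelsize=10, aggregation='mean')

print(grid_sum.shape)
print(grid_sum[0, 0, 0], grid_mean[0, 0, 0])
```

With `sum`, a voxel holds the count of each type it contains; with `mean`, it holds the fraction of each type, so crowded and sparse voxels end up on the same scale.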