# Working with Protein Tensors in Graphein

This tutorial demonstrates how to work with protein tensors through an API consistent with PyTorch Geometric, which is supported as of Graphein 1.6.0.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/a-r-j/graphein/blob/master/notebooks/protein_tensors.ipynb) [![GitHub](https://img.shields.io/badge/-View%20on%20GitHub-181717?logo=github&logoColor=ffffff)](https://github.com/a-r-j/graphein/blob/master/notebooks/protein_tensors.ipynb)

In [32]:
import graphein
graphein.verbose(False)

In [33]:
%load_ext watermark
%watermark
print("Graphein version: ", graphein.__version__)

Last updated: 2023-03-30T23:14:11.142537+01:00

Python implementation: CPython
Python version       : 3.9.16
IPython version      : 8.11.0

Compiler    : GCC 11.3.0
OS          : Linux
Release     : 5.15.0-67-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 64
Architecture: 64bit

Graphein version:  1.6.0


## 1. Parsing proteins to Tensors

The ``graphein.protein.tensor.io`` module provides several functions for parsing BioPandas DataFrames into tensors. We'll use the following example:


In [34]:
from biopandas.pdb import PandasPdb

p = PandasPdb().fetch_pdb("4hhb")
df = p.df["ATOM"]
df.head()

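| | record_name | atom_number | blank_1 | atom_name | alt_loc | residue_name | blank_2 | chain_id | residue_number | insertion | ... | x_coord | y_coord | z_coord | occupancy | b_factor | blank_4 | segment_id | element_symbol | charge | line_idx |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ATOM | 1 | | N | | VAL | | A | 1 | | ... | 19.323 | 29.727 | 42.781 | 1.0 | 49.05 | | | N | | 883 |
| 1 | ATOM | 2 | | CA | | VAL | | A | 1 | | ... | 20.141 | 30.469 | 42.414 | 1.0 | 43.14 | | | C | | 884 |
| 2 | ATOM | 3 | | C | | VAL | | A | 1 | | ... | 21.664 | 29.857 | 42.548 | 1.0 | 24.8 | | | C | | 885 |
| 3 | ATOM | 4 | | O | | VAL | | A | 1 | | ... | 21.985 | 29.541 | 43.704 | 1.0 | 37.68 | | | O | | 886 |
| 4 | ATOM | 5 | | CB | | VAL | | A | 1 | | ... | 19.887 | 31.918 | 43.524 | 1.0 | 72.12 | | | C | | 887 |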


In [35]:
import graphein.protein.tensor as gpt
coords = gpt.io.protein_df_to_tensor(df)

print(coords.shape)
gpt.plot.plot_structure(coords)

torch.Size([574, 37, 3])


This parses the atomic coordinates into a tensor of shape (Num Residues x Atom Types (default = 37) x 3). The position of each atom type along the second dimension is determined by ``ATOM_NUMBERING``:

In [36]:
from graphein.protein.resi_atoms import ATOM_NUMBERING
ATOM_NUMBERING

{'N': 0,
 'CA': 1,
 'C': 2,
 'O': 3,
 'CB': 4,
 'OG': 5,
 'CG': 6,
 'CD1': 7,
 'CD2': 8,
 'CE1': 9,
 'CE2': 10,
 'CZ': 11,
 'OD1': 12,
 'ND2': 13,
 'CG1': 14,
 'CG2': 15,
 'CD': 16,
 'CE': 17,
 'NZ': 18,
 'OD2': 19,
 'OE1': 20,
 'NE2': 21,
 'OE2': 22,
 'OH': 23,
 'NE': 24,
 'NH1': 25,
 'NH2': 26,
 'OG1': 27,
 'SD': 28,
 'ND1': 29,
 'SG': 30,
 'NE1': 31,
 'CE3': 32,
 'CZ2': 33,
 'CZ3': 34,
 'CH2': 35,
 'OXT': 36}
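
To make this layout concrete, here is a minimal pure-Python sketch of how one residue could be placed into a 37-row array. It is not Graphein's implementation: ``residue_to_atom37`` is a hypothetical helper, the mapping is only a subset of the table above, and the fill value of 1e-5 for absent atoms is assumed from the ``fill_value=1e-05`` shown in the ``Protein`` repr later in this tutorial.

```python
# Subset of the ATOM_NUMBERING table printed above (the full table has 37 entries).
ATOM_NUMBERING = {"N": 0, "CA": 1, "C": 2, "O": 3, "CB": 4, "OXT": 36}
NUM_ATOM_TYPES = 37
FILL_VALUE = 1e-5  # assumption: placeholder coordinate for atoms that are absent


def residue_to_atom37(atoms):
    """Place a residue's atoms (name -> (x, y, z)) into a 37 x 3 nested list."""
    rows = [[FILL_VALUE] * 3 for _ in range(NUM_ATOM_TYPES)]
    for name, xyz in atoms.items():
        rows[ATOM_NUMBERING[name]] = list(xyz)
    return rows


# The five atoms of VAL A1 shown in df.head() above.
val_a1 = {
    "N": (19.323, 29.727, 42.781),
    "CA": (20.141, 30.469, 42.414),
    "C": (21.664, 29.857, 42.548),
    "O": (21.985, 29.541, 43.704),
    "CB": (19.887, 31.918, 43.524),
}

rows = residue_to_atom37(val_a1)
present = [i for i, row in enumerate(rows) if row != [FILL_VALUE] * 3]
print(len(rows), present)
```

Stacking one such 37 x 3 block per residue gives the (Num Residues x 37 x 3) shape printed above.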

Similarly, we can get a tensor indicating chain membership:

In [37]:
chains = gpt.io.protein_df_to_chain_tensor(df)
print(chains.shape)
print(chains)

chains = gpt.io.protein_df_to_chain_tensor(df, one_hot=True)
print(chains)
print(chains.shape)

torch.Size([574])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
        2, 2, 2, 2, 2,
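
The two encodings can be illustrated without any tensor library. The sketch below maps chain labels to integer indices and then to one-hot rows. ``encode_chains`` and ``one_hot`` are hypothetical helpers, and assigning indices in sorted label order is an assumption that merely agrees with the A = 0, B = 1, C = 2 pattern visible in the output above.

```python
def encode_chains(chain_ids):
    """Map per-residue chain labels to integer indices (sorted label order)."""
    labels = sorted(set(chain_ids))
    index = {label: i for i, label in enumerate(labels)}
    return [index[c] for c in chain_ids], labels


def one_hot(indices, num_classes):
    """Turn integer class indices into one-hot rows."""
    return [[1 if i == k else 0 for k in range(num_classes)] for i in indices]


# Toy example: six residues spread over four chains.
chain_per_residue = ["A", "A", "B", "C", "C", "D"]
indices, labels = encode_chains(chain_per_residue)
encoded = one_hot(indices, len(labels))
print(indices)
print(encoded[2])
```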

## 2. Parsing to PyG ``Data``

We can also parse a PDB code, a PDB file path, or a UniProt ID (via the AF2 API) into a PyTorch Geometric ``Data`` object:

In [38]:
import graphein.protein.tensor as gpt
data = gpt.io.protein_to_pyg(
    pdb_code="4hhb", # Can alternatively pass a path or a uniprot ID (for AF2) with pdb_path=... and uniprot_id=...
    chain_selection="ABCD", # Select all 4 chains
    deprotonate=True, # Deprotonate the structure
    keep_insertions=False, # Remove insertions
    keep_hets=[], # Remove HETATMs
    model_index=1, # Select the first model
    # Can select a subset of atoms with atom_types=...
    )
print("Data: ", data)
print("ID: ", data.id)
print("Residues: ", data.residues[:10])
print("Residue IDs: ", data.residue_id[:10])

Data:  Data(coords=[574, 37, 3], residues=[574], id='4hhb_ABCD', residue_id=[574], residue_type=[574], chains=[574])
ID:  4hhb_ABCD
Residues:  ['VAL', 'LEU', 'SER', 'PRO', 'ALA', 'ASP', 'LYS', 'THR', 'ASN', 'VAL']
Residue IDs:  ['A:VAL:1', 'A:LEU:2', 'A:SER:3', 'A:PRO:4', 'A:ALA:5', 'A:ASP:6', 'A:LYS:7', 'A:THR:8', 'A:ASN:9', 'A:VAL:10']


In [39]:
from torch_geometric.data import Batch

batch = Batch.from_data_list([data, data, data])
batch

DataBatch(coords=[1722, 37, 3], residues=[3], id=[3], residue_id=[3], residue_type=[1722], chains=[1722])
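
The batched ``coords`` has 1722 rows because the three copies of the 574-residue structure are concatenated along the first dimension. The pure-Python sketch below shows the bookkeeping that such a concatenation needs in order to be reversible. ``batch_layout`` is a hypothetical helper. The names ``ptr`` and ``batch`` follow PyG's usual convention, but whether PyG populates them for this particular object is not shown in the output above.

```python
from itertools import accumulate


def batch_layout(sizes):
    """Return cumulative offsets and a per-row graph index for concatenated graphs."""
    ptr = [0, *accumulate(sizes)]
    batch = [graph for graph, n in enumerate(sizes) for _ in range(n)]
    return ptr, batch


ptr, batch = batch_layout([574, 574, 574])
print(ptr)
print(len(batch), batch[573], batch[574])
```

Row ``i`` of the concatenated tensor belongs to the graph ``batch[i]``, and graph ``g`` occupies rows ``ptr[g]`` up to (but excluding) ``ptr[g + 1]``.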

## 3. Parsing to Graphein ``Protein``

Graphein provides a ``Protein`` class which wraps the PyG data classes with useful protein-specific methods. There are several ways to parse a structure into this format:

In [40]:
from rich import inspect
inspect(gpt.Protein)

In [41]:
import graphein.protein.tensor as gpt

# 1. From a PDB code
protein = gpt.Protein()
protein.from_pdb_code("4hhb", chain_selection="ABCD")

# 2. From a PDB file
protein = gpt.Protein()
#protein.from_pdb_file("./examples/pdbs/3eiy.pdb")

# 3. From a BioPandas DataFrame
p = PandasPdb().fetch_pdb("4hhb")
df = p.df["ATOM"]
protein = gpt.Protein().from_dataframe(df)

# 4. From a PyG Data object
data = gpt.io.protein_to_pyg(
    pdb_code="4hhb", # Can alternatively pass a path or a uniprot ID (for AF2) with pdb_path=... and uniprot_id=...
    chain_selection="ABCD", # Select all 4 chains
    deprotonate=True, # Deprotonate the structure
    keep_insertions=False, # Remove insertions
    keep_hets=[], # Remove HETATMs
    model_index=1, # Select the first model
    # Can select a subset of atoms with atom_types=...
    )
protein = gpt.Protein().from_data(data)

print(protein)

Protein(fill_value=1e-05, atom_list=[37], residue_type=[574], id='4hhb_ABCD', residue_id=[574], chains=[574], coords=[574, 37, 3], residues=[574])


### Selecting Views and Representations

In [42]:
protein.alpha_carbon(cache="ca")
protein.full_atom_coords(cache="fa")
protein.backbone_frames(cache="frame")

print("Alpha Carbon coordinates: ", protein.ca.shape)
print("Backbone coordinates: ", protein.backbone().shape) # TODO
print("Full atom coords: ", protein.fa[0].shape)
print("Backbone Frames: (rotations) ", protein.frame[0].shape)
print("Backbone Frames: (translation) ", protein.frame[1].shape)

Alpha Carbon coordinates:  torch.Size([574, 3])
Backbone coordinates:  torch.Size([574, 4, 3])
Full atom coords:  torch.Size([4384, 3])
Backbone Frames: (rotations)  torch.Size([574, 3, 3])
Backbone Frames: (translation)  torch.Size([574, 1, 3])
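
The shapes above are consistent with each view being a selection from the (Num Residues x 37 x 3) tensor: one row per residue for the alpha carbons, four rows per residue for the backbone, and one row per atom actually present for the full-atom view. The following sketch reproduces these selections on a toy structure with made-up coordinates. All helper names are hypothetical, the backbone is assumed to be N, CA, C, O (indices 0 to 3 in ``ATOM_NUMBERING``), and returning a per-atom residue index alongside the coordinates is a design choice of this sketch rather than a documented Graphein behaviour.

```python
FILL_VALUE = 1e-5  # assumption: placeholder coordinate for missing atoms
NUM_ATOM_TYPES = 37
N, CA, C, O, CB = 0, 1, 2, 3, 4  # indices taken from ATOM_NUMBERING


def make_residue(atoms):
    """Build a 37 x 3 nested list from {atom index: (x, y, z)}."""
    rows = [[FILL_VALUE] * 3 for _ in range(NUM_ATOM_TYPES)]
    for idx, xyz in atoms.items():
        rows[idx] = list(xyz)
    return rows


def alpha_carbon(coords):
    """One CA position per residue: shape (num_residues, 3)."""
    return [res[CA] for res in coords]


def backbone(coords):
    """N, CA, C, O per residue: shape (num_residues, 4, 3)."""
    return [[res[i] for i in (N, CA, C, O)] for res in coords]


def full_atom(coords):
    """All atoms that are not fill values, plus the residue each belongs to."""
    atoms, residue_index = [], []
    for r, res in enumerate(coords):
        for atom in res:
            if atom != [FILL_VALUE] * 3:
                atoms.append(atom)
                residue_index.append(r)
    return atoms, residue_index


# Toy two-residue structure (made-up coordinates, not taken from 4hhb).
coords = [
    make_residue({N: (0.0, 0.0, 0.0), CA: (1.5, 0.0, 0.0),
                  C: (2.0, 1.4, 0.0), O: (1.3, 2.4, 0.0)}),
    make_residue({N: (3.3, 1.5, 0.0), CA: (4.0, 2.8, 0.0),
                  C: (5.5, 2.6, 0.0), O: (6.1, 1.5, 0.0), CB: (3.6, 3.6, 1.2)}),
]

atoms, residue_index = full_atom(coords)
print(len(alpha_carbon(coords)), len(backbone(coords)[0]), len(atoms))
```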


### Computing Edges

In [43]:
protein.edges("eps_8", cache="eps")
protein.edges("knn_16", cache="knn")

print("Radius Edges: ", protein.eps.shape)
print("KNN Edges: ", protein.knn.shape)

protein.edges("eps_8", cache="eps", loop=True)
protein.edges("knn_16", cache="knn", loop=True)

print("Radius Edges: ", protein.eps.shape)
print("KNN Edges: ", protein.knn.shape)

Radius Edges:  torch.Size([2, 5634])
KNN Edges:  torch.Size([2, 9184])
Radius Edges:  torch.Size([2, 6208])
KNN Edges:  torch.Size([2, 9184])
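
The counts above follow from simple arithmetic: 574 residues x 16 neighbours gives 9184 k-NN edges, and enabling self-loops adds one edge per residue to the radius graph (6208 - 5634 = 574). The k-NN count does not change with ``loop=True``, which is what one would expect if each node then counts itself among its k nearest neighbours. The brute-force sketch below reproduces this behaviour on toy points. ``radius_edges`` and ``knn_edges`` are hypothetical helpers, and the inclusive distance threshold is an assumption.

```python
import math


def radius_edges(points, r, loop=False):
    """Directed (source, target) pairs whose distance is at most r."""
    edges = []
    for i, p in enumerate(points):
        for j, q in enumerate(points):
            if i == j and not loop:
                continue
            if math.dist(p, q) <= r:
                edges.append((j, i))
    return edges


def knn_edges(points, k, loop=False):
    """Connect every point to its k nearest neighbours (itself included if loop)."""
    edges = []
    for i, p in enumerate(points):
        candidates = [j for j in range(len(points)) if loop or j != i]
        candidates.sort(key=lambda j: math.dist(p, points[j]))
        edges.extend((j, i) for j in candidates[:k])
    return edges


# Five toy points on a line; the last one is isolated.
points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0),
          (3.0, 0.0, 0.0), (10.0, 0.0, 0.0)]

print(len(radius_edges(points, 1.5)), len(radius_edges(points, 1.5, loop=True)))
print(len(knn_edges(points, 2)), len(knn_edges(points, 2, loop=True)))
```

The quadratic loops are only suitable for illustration. Real implementations typically use spatial data structures or batched tensor operations.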


### Computing Edge Distances

In [44]:
protein.edge_distances(protein.ca, protein.knn).shape

torch.Size([9184])
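
The result has one distance per edge, which is why its length matches the 9184 k-NN edges. As a rough illustration, the sketch below computes Euclidean distances for an edge index stored as two parallel lists (sources and targets), mirroring the ``[2, num_edges]`` layout. The helper is hypothetical and uses toy coordinates.

```python
import math


def edge_distances(points, edge_index):
    """Euclidean length of each edge; edge_index is [sources, targets]."""
    sources, targets = edge_index
    return [math.dist(points[s], points[t]) for s, t in zip(sources, targets)]


points = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0), (3.0, 4.0, 12.0)]
edge_index = [[0, 1, 0], [1, 2, 2]]  # edges 0->1, 1->2, 0->2

distances = edge_distances(points, edge_index)
print(distances)
```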

### Computing Features

In [45]:
# Angles
print("Dihedral angles:", protein.dihedrals())
print("Sidechain Torsion angles: ", protein.sidechain_torsion())
print("Kappa: ", protein.kappa(rad=False, embed=False))
print("Alpha: ", protein.alpha(embed=False))

Dihedral angles: tensor([[ 1.0000,  0.0000,  0.9463,  0.3232,  0.1277, -0.9918],
        [ 0.2636,  0.9646, -0.5480,  0.8365, -0.0836,  0.9965],
        [-0.4605,  0.8877,  0.9947, -0.1024, -0.9744,  0.2249],
        ...,
        [-0.8991, -0.4378, -0.5160,  0.8566, -0.6029, -0.7978],
        [-0.9993,  0.0365, -0.0748, -0.9972,  0.9891, -0.1474],
        [-0.6320,  0.7750,  1.0000,  0.0000,  1.0000,  0.0000]])
Sidechain Torsion angles:  tensor([[ 0.9987, -0.0519,  1.0000,  ...,  0.0000,  1.0000,  0.0000],
        [ 0.9996, -0.0280,  1.0000,  ...,  0.0000,  1.0000,  0.0000],
        [ 0.9997,  0.0230,  1.0000,  ...,  0.0000,  1.0000,  0.0000],
        ...,
        [ 1.0000, -0.0071,  0.9994,  ..., -0.0541,  0.9991,  0.0427],
        [ 0.9997, -0.0246,  0.9994,  ...,  0.0000,  1.0000,  0.0000],
        [ 0.9999,  0.0168,  0.9995,  ...,  0.0000,  1.0000,  0.0000]])
Kappa:  tensor([  0.0000,   0.0000,  29.4452, 124.3686, 102.3605, 105.1655, 107.3162,
        111.5782, 113.1725, 113.7545, 
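
The dihedral tensor printed above has six values per residue, and each adjacent pair lies on the unit circle (for example 0.9463² + 0.3232² ≈ 1). This suggests that three backbone angles are each embedded as a cosine/sine pair, although the exact ordering and sign conventions Graphein uses are not shown here. The sketch below is only an illustration of the general technique: it computes a signed dihedral angle from four points and embeds it as (cos, sin). All function names are hypothetical.

```python
import math


def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle in radians about the p1-p2 bond."""
    b0 = _sub(p0, p1)
    b1 = _sub(p2, p1)
    b2 = _sub(p3, p2)
    norm = math.sqrt(_dot(b1, b1))
    b1 = tuple(x / norm for x in b1)
    # Project b0 and b2 onto the plane perpendicular to the central bond.
    v = _sub(b0, tuple(_dot(b0, b1) * x for x in b1))
    w = _sub(b2, tuple(_dot(b2, b1) * x for x in b1))
    return math.atan2(_dot(_cross(b1, v), w), _dot(v, w))


def embed(angle):
    """Embed an angle on the unit circle so that -pi and pi map to the same point."""
    return (math.cos(angle), math.sin(angle))


# Central bond along z; the outer atoms are either eclipsed (cis) or opposite (trans).
cis = dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (1, 0, 1))
trans = dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (-1, 0, 1))
print(cis, trans, embed(cis))
```

Embedding on the unit circle avoids the discontinuity at ±180°, which is one common reason to prefer it over raw angles as model input.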

## Plotting

We provide several plotting utilities built with Plotly. This means plots can easily be logged to WandB.

In [46]:
protein.plot_structure()

In [47]:
protein.plot_distance_matrix()

In [48]:
protein.plot_dihedrals()

### Testing

In [49]:
# Check if the protein has a complete backbone
protein.has_complete_backbone()

tensor(True)

In [50]:
# Checks all expected atoms are present for each residue
protein.is_complete()

True

## ProteinBatch

In [51]:
import graphein.protein.tensor as gpt

batch = gpt.data.ProteinBatch().from_pdb_codes(pdb_codes=["3eiy", "4hhb", "1a0q"])
print(batch)

DataProteinBatch(fill_value=[3], atom_list=[3], residue_type=[1159], id=[3], residue_id=[3], chains=[1159], coords=[1159, 37, 3], residues=[3])


In [52]:
batch.plot_structure()

### Unbatching and apply

We can quickly unbatch a ``ProteinBatch`` into a list of ``Protein``s:

In [53]:
proteins = batch.to_protein_list()
proteins

[Protein(fill_value=[1], atom_list=[37], residue_type=[174], id='3eiy', residue_id=[174], chains=[174], coords=[174, 37, 3], residues=[174]),
 Protein(fill_value=[1], atom_list=[37], residue_type=[574], id='4hhb', residue_id=[574], chains=[574], coords=[574, 37, 3], residues=[574]),
 Protein(fill_value=[1], atom_list=[37], residue_type=[411], id='1a0q', residue_id=[411], chains=[411], coords=[411, 37, 3], residues=[411])]
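
The per-protein sizes above add up to the batch size printed earlier (174 + 574 + 411 = 1159). Conceptually, unbatching slices the concatenated residue dimension back into chunks of those sizes. The following pure-Python sketch shows that slicing. ``split_by_sizes`` is a hypothetical helper, not the method Graphein uses internally.

```python
def split_by_sizes(rows, sizes):
    """Split a concatenated sequence into consecutive chunks of the given sizes."""
    if sum(sizes) != len(rows):
        raise ValueError("sizes must add up to the number of rows")
    chunks, start = [], 0
    for n in sizes:
        chunks.append(rows[start:start + n])
        start += n
    return chunks


sizes = [174, 574, 411]          # residues in 3eiy, 4hhb, 1a0q
rows = list(range(sum(sizes)))   # stand-in for the 1159 batched residues
chunks = split_by_sizes(rows, sizes)
print(len(rows), [len(c) for c in chunks])
```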

We can also define a function that operates on a single ``Protein`` and apply it to all proteins in the ``ProteinBatch`` using the ``protein_apply()`` method:

In [54]:
from graphein.protein.tensor.plot import plot_structure

def single_plot(protein: gpt.Protein):
    return plot_structure(protein.coords, lines=False)

plots = batch.protein_apply(single_plot)
plots[2]

If we want to apply the function to a single protein, we can use ``apply_to`` with a specified index:

In [55]:
batch.apply_to(single_plot, 0)

In [56]:
batch.apply_to(single_plot, 1)
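
The relationship between the two methods can be summarised as "map over every protein" versus "call on one protein by index". The toy classes below mimic only that calling pattern. ``ToyProtein`` and ``ToyBatch`` are hypothetical stand-ins and say nothing about how ``ProteinBatch`` is implemented.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ToyProtein:
    id: str
    num_residues: int


class ToyBatch:
    """Minimal stand-in that holds a list of proteins."""

    def __init__(self, proteins: List[ToyProtein]):
        self.proteins = list(proteins)

    def protein_apply(self, fn: Callable[[ToyProtein], Any]) -> List[Any]:
        # Apply fn to every protein and collect the results in order.
        return [fn(p) for p in self.proteins]

    def apply_to(self, fn: Callable[[ToyProtein], Any], idx: int) -> Any:
        # Apply fn to the protein at position idx only.
        return fn(self.proteins[idx])


toy_batch = ToyBatch([ToyProtein("3eiy", 174), ToyProtein("4hhb", 574), ToyProtein("1a0q", 411)])


def count_residues(protein: ToyProtein) -> int:
    return protein.num_residues


print(toy_batch.protein_apply(count_residues))
print(toy_batch.apply_to(count_residues, 1))
```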

## Saving ``Protein``s and ``ProteinBatch``es

### Saving as Tensor

In [57]:
import torch
import graphein.protein.tensor as gpt

# Save a protein
protein = gpt.data.get_random_protein()
torch.save(protein, "test_protein.pt")
protein.save("test_protein_2.pt")

batch = gpt.data.get_random_batch(16)
torch.save(batch, "test_batch.pt")
batch.save("test_batch_2.pt")

In [58]:
print(torch.load("test_protein.pt"))
print(torch.load("test_protein_2.pt"))

Protein(fill_value=1e-05, atom_list=[37], residue_type=[574], id='4hhb', residue_id=[574], chains=[574], coords=[574, 37, 3], residues=[574])
Protein(fill_value=1e-05, atom_list=[37], residue_type=[574], id='4hhb', residue_id=[574], chains=[574], coords=[574, 37, 3], residues=[574])


In [59]:
print(torch.load("test_batch.pt"))
print(torch.load("test_batch_2.pt"))

DataProteinBatch(fill_value=[16], atom_list=[16], residue_type=[5395], id=[16], residue_id=[16], chains=[5395], coords=[5395, 37, 3], residues=[16])
DataProteinBatch(fill_value=[16], atom_list=[16], residue_type=[5395], id=[16], residue_id=[16], chains=[5395], coords=[5395, 37, 3], residues=[16])


### To PDB Files

*N.B.* If you work with PDB files in VSCode, you may find [this extension](https://marketplace.visualstudio.com/items?itemName=ArianJamasb.protein-viewer) useful :+)

In [60]:
import torch
import graphein.protein.tensor as gpt

# Save a protein
protein = gpt.data.get_random_protein()
protein.to_pdb("test_protein.pdb")

In [61]:
from biopandas.pdb import PandasPdb

p = PandasPdb().read_pdb("test_protein.pdb")
p.df["ATOM"].head()

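| | record_name | atom_number | blank_1 | atom_name | alt_loc | residue_name | blank_2 | chain_id | residue_number | insertion | ... | x_coord | y_coord | z_coord | occupancy | b_factor | blank_4 | segment_id | element_symbol | charge | line_idx |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ATOM | 1 | | N | | GLN | | A | 1 | | ... | 25.569 | 27.14 | -2.829 | 1.0 | 0.0 | | | N | 0.0 | 0 |
| 1 | ATOM | 2 | | CA | | GLN | | A | 1 | | ... | 26.022 | 27.178 | -1.417 | 1.0 | 0.0 | | | C | 0.0 | 1 |
| 2 | ATOM | 3 | | C | | GLN | | A | 1 | | ... | 25.202 | 26.176 | -0.607 | 1.0 | 0.0 | | | C | 0.0 | 2 |
| 3 | ATOM | 4 | | O | | GLN | | A | 1 | | ... | 24.418 | 25.401 | -1.149 | 1.0 | 0.0 | | | O | 0.0 | 3 |
| 4 | ATOM | 5 | | CB | | GLN | | A | 1 | | ... | 25.856 | 28.603 | -0.862 | 1.0 | 0.0 | | | C | 0.0 | 4 |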


### To DataFrame

In [62]:
import graphein.protein.tensor as gpt

protein = gpt.data.get_random_protein()
df = protein.to_dataframe()
df.head()

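| | record_name | atom_number | blank_1 | atom_name | alt_loc | residue_name | blank_2 | chain_id | residue_number | insertion | ... | x_coord | y_coord | z_coord | occupancy | b_factor | blank_4 | segment_id | element_symbol | charge | line_idx |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | ATOM | 1 | | N | | GLN | | 0 | 1 | | ... | 25.569 | 27.139999 | -2.829 | 1.0 | 0.0 | | | N | 0 | 1 |
| 1 | ATOM | 2 | | CA | | GLN | | 0 | 1 | | ... | 26.021999 | 27.177999 | -1.417 | 1.0 | 0.0 | | | C | 0 | 2 |
| 2 | ATOM | 3 | | C | | GLN | | 0 | 1 | | ... | 25.202 | 26.176001 | -0.607 | 1.0 | 0.0 | | | C | 0 | 3 |
| 3 | ATOM | 4 | | O | | GLN | | 0 | 1 | | ... | 24.417999 | 25.400999 | -1.149 | 1.0 | 0.0 | | | O | 0 | 4 |
| 4 | ATOM | 5 | | CB | | GLN | | 0 | 1 | | ... | 25.856001 | 28.603001 | -0.862 | 1.0 | 0.0 | | | C | 0 | 5 |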
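
Comparing this output with the table read back from ``test_protein.pdb``, the same coordinate appears as 27.14 in one and 27.139999 in the other. A plausible explanation, offered here as an assumption, is that the tensor stores single-precision floats, while the PDB format writes coordinates in fixed 8-character fields with three decimals. The stdlib sketch below reproduces both effects. ``to_float32`` and ``pdb_coord`` are hypothetical helpers.

```python
import struct


def to_float32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def pdb_coord(x):
    """Format a coordinate as an 8-character PDB field with three decimals."""
    return f"{x:8.3f}"


y32 = to_float32(27.14)
print(round(y32, 6))              # value as it appears when held in single precision
print(repr(pdb_coord(y32)))       # value as it would be written to a PDB file
print(float(pdb_coord(y32)))      # value when the PDB file is read back
```

If exact equality with the original file matters, comparing coordinates with a tolerance of about 1e-3 is safer than testing for equality.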
