In [None]:
# !pip install --use-pep517 "graphein[extras]" lightning torch torch-geometric tensorboard nbformat "jsonargparse[signatures]" ipywidgets tabulate

In [None]:
import graphein
graphein.verbose(enabled=False)
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.features.nodes import amino_acid as graphein_nodes
from graphein.protein import edges as graphein_edges
from graphein.protein.subgraphs import extract_subgraph
from graphein.protein.visualisation import plotly_protein_structure_graph
from functools import partial
from matplotlib import colormaps

## Converting proteins to featurized graphs

The [graphein](https://graphein.ai/) library provides functionality for producing a number of types of graph-based representations of proteins. We'll use it to construct [NetworkX](https://github.com/networkx/networkx) graphs from protein structures, to extract interface residues, and to featurise the nodes and edges of the graph.

Here we use the node features implemented in `graphein.protein.features.nodes.amino_acid`, but there are many more kinds of node features available in the library (see the full [API](https://graphein.ai/modules/graphein.protein.html#features)).

In [None]:
# Node features: one-hot amino-acid encoding plus the Meiler embedding.
# Edges: peptide bonds plus distance-threshold edges.
graph_config = ProteinGraphConfig(
    node_metadata_functions=[
        graphein_nodes.amino_acid_one_hot,
        graphein_nodes.meiler_embedding,
    ],
    edge_construction_functions=[
        graphein_edges.add_peptide_bonds,
        partial(
            graphein_edges.add_distance_threshold,
            long_interaction_threshold=2,
            threshold=8.0,
        ),
    ],
)

In [None]:
# Build the graph for PDB entry 1A0G with the features and edges configured in graph_config.
graph = construct_graph(config=graph_config, pdb_code="1A0G")

Now we have a graph object consisting of nodes and edges, each associated with the attributes we've specified.

In [None]:
# Preview the first few nodes together with their attribute dictionaries.
for idx, (node, node_data) in enumerate(graph.nodes(data=True)):
    print("Node:", node)
    print("Node attributes:", node_data)
    if idx > 5:
        break

In [None]:
# Preview the first few edges together with their attribute dictionaries.
for idx, (start_node, end_node, edge_data) in enumerate(graph.edges(data=True)):
    print(f"Edge between {start_node} and {end_node}")
    print("Edge attributes:", edge_data)
    if idx > 5:
        break

In [None]:
# Plot the full graph: nodes coloured by their chain, edges coloured by edge kind.
p = plotly_protein_structure_graph(
    graph,
    colour_edges_by="kind",
    colour_nodes_by="chain_id",
    label_node_ids=False,
    # Title fixed: the original was missing the space after "connections."
    plot_title="Peptide backbone graph with distance connections. Nodes coloured by chain.",
    node_size_multiplier=1,
)
p.show()

We can extract interface residues from this graph by checking for edges between chains:

In [None]:
# Collect the chain A residues that share a distance-threshold edge with a
# residue of any other chain. Node IDs start with the chain ID, followed by ":".
interface_residues = set()
for u, v, edge_attrs in graph.edges(data=True):
    if 'distance_threshold' not in edge_attrs['kind']:
        continue
    u_chain = u.split(":")[0]
    v_chain = v.split(":")[0]
    if u_chain == "A" and v_chain != "A":
        interface_residues.add(u)
    elif v_chain == "A" and u_chain != "A":
        interface_residues.add(v)
interface_residues

This information can be added to the graph as an `interface_label` node feature:

In [None]:
# Label every node: 1 if it was collected as an interface residue, otherwise 0.
for node, attrs in graph.nodes(data=True):
    attrs['interface_label'] = 1 if node in interface_residues else 0

Let's see where the interface is for this example:

In [None]:
# Plot the full graph again, this time colouring nodes by the interface label.
p = plotly_protein_structure_graph(
    graph,
    plot_title="Peptide backbone graph with distance connections. Nodes coloured by interface labels.",
    colour_nodes_by="interface_label",
    colour_edges_by="kind",
    edge_color_map=colormaps["Pastel2"],
    node_size_multiplier=1,
    label_node_ids=False,
)
p.show()

Since our task is to predict interface residues given just one input chain, we'll extract the subgraph for the chain of interest:

In [None]:
# Keep only the nodes of chain A; the interface labels set earlier stay on the nodes.
chain_subgraph = extract_subgraph(graph, chains="A")

# Plot the single-chain graph, colouring nodes by the interface label.
p = plotly_protein_structure_graph(
    chain_subgraph,
    plot_title="Peptide backbone graph. Nodes coloured by interface_label.",
    colour_nodes_by="interface_label",
    colour_edges_by="kind",
    edge_color_map=colormaps["Pastel2"],
    node_size_multiplier=1,
    label_node_ids=False,
)
p.show()

We put all this together in a function to use in the later notebooks. Feel free to add other node features, edge types, and edge features to your function.

In [None]:
def load_graph(pdb_id, chain, threshold=8., long_interaction_threshold=2):
    """Build a featurised graph for one chain of a PDB entry, with interface labels.

    The graph of the whole entry is constructed first so that contacts between
    chains are visible. Residues of ``chain`` that share a distance-threshold
    edge with a residue of any other chain are recorded as interface residues.
    The graph is then reduced to the requested chain and every remaining node
    receives an ``interface_label`` attribute.

    Args:
        pdb_id: PDB code passed to ``construct_graph``.
        chain: ID of the chain to keep. It is compared against the part of
            each node ID that precedes the first ":".
        threshold: Value forwarded as ``threshold`` to
            ``graphein_edges.add_distance_threshold``. Defaults to 8.
        long_interaction_threshold: Value forwarded as
            ``long_interaction_threshold`` to
            ``graphein_edges.add_distance_threshold``. Defaults to 2.

    Returns:
        The subgraph containing only nodes of ``chain``; each node's data has
        ``interface_label`` set to 1 (interface residue) or 0 (otherwise).
    """
    graph_config = ProteinGraphConfig(
        node_metadata_functions=[
            graphein_nodes.amino_acid_one_hot,
            graphein_nodes.meiler_embedding,
        ],
        edge_construction_functions=[
            graphein_edges.add_peptide_bonds,
            partial(
                graphein_edges.add_distance_threshold,
                threshold=threshold,
                long_interaction_threshold=long_interaction_threshold,
            ),
        ],
    )
    full_graph = construct_graph(pdb_code=pdb_id, config=graph_config, verbose=False)

    # Interface residues must be found on the full graph, before other chains are dropped.
    interface_residues = set()
    for source, target, edge_data in full_graph.edges(data=True):
        if 'distance_threshold' not in edge_data['kind']:
            continue
        c1, c2 = source.split(":")[0], target.split(":")[0]
        if c1 == chain and c2 != chain:
            interface_residues.add(source)
        elif c2 == chain and c1 != chain:
            interface_residues.add(target)

    # Pass the chain as a one-element list: with a bare string, membership tests
    # against it are substring tests, so a multi-character chain ID could also
    # select chains whose ID is a substring of it.
    chain_graph = extract_subgraph(full_graph, chains=[chain])
    for node, data in chain_graph.nodes(data=True):
        data['interface_label'] = 1 if node in interface_residues else 0
    return chain_graph

## Bonus


We can also add our own edge functions or node features that are not implemented in the graphein API. For example, we can calculate the solvent accessible surface area (SASA) for each residue and include it as a node feature.

In [None]:
from Bio.PDB.mmtf import MMTFParser
from Bio.PDB.SASA import ShrakeRupley
import warnings
# NOTE: no category or module is given, so this suppresses every warning from
# here on, not only the ones raised while parsing structures.
warnings.filterwarnings("ignore") # to ignore warnings when parsing pdb structures

def add_sasa(pdb_id, graph):
    """Annotate every node of ``graph`` with a residue-level SASA value.

    The structure for ``pdb_id`` is fetched with ``MMTFParser``, the
    Shrake-Rupley algorithm is run at residue level on its first model, and
    the resulting value is stored under the ``'sasa'`` key of each node's data.
    Residues are looked up by the node's ``chain_id`` and ``residue_number``.

    Args:
        pdb_id: PDB code of the structure to download.
        graph: Graph whose node data contain ``chain_id`` and
            ``residue_number``; it is modified in place.

    Returns:
        The same ``graph`` object, with ``'sasa'`` added to every node.

    Raises:
        KeyError: If a node's chain or residue number is not found in the
            first model of the downloaded structure.
    """
    # NOTE(review): this downloads from the RCSB MMTF service on every call;
    # confirm that endpoint is still available in your environment.
    struct = MMTFParser.get_structure_from_url(pdb_id)
    # Only the first model is used for the lookup below, so compute SASA on that
    # model alone. Passing the whole structure would pool the atoms of every
    # model in multi-model entries and distort the per-residue values.
    model = struct[0]
    sr = ShrakeRupley()
    # NOTE(review): SASA is computed with all chains of the model present, so
    # residues buried by a partner chain get lower values than they would in
    # the isolated chain. Confirm this is intended for a single-chain task.
    sr.compute(model, level="R") # residue level
    for _, data in graph.nodes(data=True):
        # add SASA to node features
        data['sasa'] = model[data['chain_id']][data['residue_number']].sasa
    return graph

# Rebuild the chain A graph for 1A0G, then attach the SASA value to each node.
graph = load_graph(pdb_id="1A0G", chain="A")
graph = add_sasa(pdb_id="1A0G", graph=graph)

In [None]:
# Preview the first few nodes to check that the new 'sasa' attribute is present.
for idx, (node, node_data) in enumerate(graph.nodes(data=True)):
    print("Node:", node)
    print("Node attributes:", node_data)
    if idx > 5:
        break

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Gather the per-node SASA values and interface labels as two parallel columns.
sasa = [attrs['sasa'] for _, attrs in graph.nodes(data=True)]
interface_labels = [attrs['interface_label'] for _, attrs in graph.nodes(data=True)]
data = {
    "sasa": sasa,
    "interface_labels": interface_labels,
}

# Compare the SASA distributions of the two interface-label groups.
plt.figure(figsize=(14, 8))
sns.violinplot(data=data, x="interface_labels", y="sasa")
plt.title("SASA between interface and non-interface residues")
plt.tight_layout()
plt.show()

What other node or edge features would you like to include in your graph?