In [1]:
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph

In [2]:
# Path to your local PDB file
pdb_path = "../data/raw/1f8a.pdb"

In [3]:
# Create a protein graph configuration
config = ProteinGraphConfig()

config.dict()

/scratch/local/51179627/ipykernel_1276145/2352068268.py:4: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  config.dict()


{'granularity': 'CA',
 'keep_hets': [],
 'insertions': True,
 'alt_locs': 'max_occupancy',
 'pdb_dir': None,
 'verbose': False,
 'exclude_waters': True,
 'deprotonate': False,
 'protein_df_processing_functions': None,
 'edge_construction_functions': [<function graphein.protein.edges.distance.add_peptide_bonds(G: 'nx.Graph') -> 'nx.Graph'>],
 'node_metadata_functions': [<function graphein.protein.features.nodes.amino_acid.meiler_embedding(n: str, d: Dict[str, Any], return_array: bool = False) -> Union[pandas.core.series.Series, numpy.ndarray]>],
 'edge_metadata_functions': None,
 'graph_metadata_functions': None,
 'get_contacts_config': None,
 'dssp_config': None}

In [4]:
# Construct the residue graph from the PDB file
residue_graph = construct_graph(config=config, path=pdb_path, )

# Display basic information about the graph
# print(f"Number of nodes: {len(residue_graph.nodes())}")
# print(f"Number of edges: {len(residue_graph.edges())}")

Output()

In [5]:
# print(residue_graph.nodes())

In [6]:
# print(residue_graph.edges())

In [7]:
# Extract the protein sequence
# sequence = "".join([data["residue_name"] for _, data in residue_graph.nodes(data=True)])

# print("Protein Sequence:", sequence)

In [8]:
# len(sequence)

In [9]:
chains = set(data["chain_id"] for _, data in residue_graph.nodes(data=True))
print("Chains in Graph:", chains)

Chains in Graph: {'C', 'B'}


In [10]:
nodes_residues = [data["residue_name"] for _, data in residue_graph.nodes(data=True)]
print(f"Number of Residues in Graph: {len(nodes_residues)}")


Number of Residues in Graph: 160


In [11]:
for _, data in residue_graph.nodes(data=True):
    print(data)
    break

{'chain_id': 'B', 'residue_name': 'GLY', 'residue_number': 1, 'atom_type': 'CA', 'element_symbol': 'C', 'coords': array([-0.294,  5.819, 75.824], dtype=float32), 'b_factor': 36.13999938964844, 'meiler': dim_1    0.00
dim_2    0.00
dim_3    0.00
dim_4    0.00
dim_5    6.07
dim_6    0.13
dim_7    0.15
Name: GLY, dtype: float64}


In [12]:
import networkx as nx
# Create a mapping from original nodes to integers
node_mapping = {node: idx for idx, node in enumerate(residue_graph.nodes())}

# Relabel the nodes in the graph
numeric_graph = nx.relabel_nodes(residue_graph, node_mapping)

# # Display the relabeled graph
# print(f"Nodes (numeric): {list(numeric_graph.nodes())}")
# print(f"Edges: {list(numeric_graph.edges())}")

# # Example: Access original attributes
# for node in numeric_graph.nodes():
#     original_node = list(node_mapping.keys())[list(node_mapping.values()).index(node)]
#     print(f"Original Node: {original_node}, Attributes: {residue_graph.nodes[original_node]}")

In [13]:
# from torch_geometric.utils import from_networkx
# # Step 2: Convert NetworkX graph to PyTorch Geometric Data
# data = from_networkx(residue_graph)

# # Check the PyTorch Geometric Data object
# print(data)

In [14]:
import torch
from torch_geometric.data import Data

# Assume 'residue_graph' is your NetworkX graph
# Create a mapping from node identifiers to numeric indices
node_mapping = {node: idx for idx, node in enumerate(residue_graph.nodes())}

# Remap edges using the node mapping
edge_list = [(node_mapping[u], node_mapping[v]) for u, v in residue_graph.edges()]

# Convert edge_list to a tensor
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

print(edge_index.shape)  # Should be [2, num_edges]

# Example Node Features (replace with your actual features)
num_nodes = len(node_mapping)  # Total number of nodes
node_features = torch.rand(num_nodes, 10)  # Example: 10 features per node

# Create PyTorch Geometric graph object
graph = Data(x=node_features, edge_index=edge_index)

print(graph)

torch.Size([2, 155])
Data(x=[160, 10], edge_index=[2, 155])


In [15]:
def validate_graph(graph):
    try:
        # Check if it's a PyTorch Geometric Data object
        if not isinstance(graph, Data):
            return False, f"File at is not a PyTorch Geometric Data object."

        # Check if `x` (node features) exists and is non-empty
        if not hasattr(graph, "x") or graph.x is None or graph.x.size(0) == 0:
            return False, f"Graph at has no valid node features."

        # Check if `edge_index` exists and is non-empty
        if not hasattr(graph, "edge_index") or graph.edge_index is None or graph.edge_index.size(1) == 0:
            return False, f"Graph at has no valid edges."

        # Check if `edge_index` indices are within the valid range
        if graph.edge_index.max() >= graph.x.size(0):
            return False, (
                f"Graph at has invalid edges. "
                f"Max index in edge_index: {graph.edge_index.max()}, Num nodes: {graph.x.size(0)}."
            )

        # Optional: Check if edge_index is symmetric for undirected graphs
        if not is_edge_index_symmetric(graph.edge_index):
            return False, f"Graph at has a non-symmetric edge_index for an undirected graph."

        # Optional: Check if `edge_index` contains duplicate edges
        if has_duplicate_edges(graph.edge_index):
            return False, f"Graph at contains duplicate edges."

        # If all checks pass
        return True, f"Graph at is valid."

    except Exception as e:
        return False, f"Error loading or validating graph at: {str(e)}"

def is_edge_index_symmetric(edge_index):
    # Flip the edge index to get reversed edges
    edge_index_flipped = edge_index.flip(0)

    # Combine original and flipped edges
    combined_edges = torch.cat([edge_index, edge_index_flipped], dim=1)

    # Remove duplicate edges
    unique_edges = torch.unique(combined_edges, dim=1)

    # Check if all edges have their reverses
    return unique_edges.size(1) == combined_edges.size(1)


def has_duplicate_edges(edge_index):
    """
    Checks if `edge_index` contains duplicate edges.
    
    Args:
        edge_index (Tensor): The edge_index tensor.
        
    Returns:
        bool: True if duplicates exist, False otherwise.
    """
    edges = edge_index.t().tolist()
    unique_edges = set(map(tuple, edges))
    return len(edges) != len(unique_edges)

In [16]:
is_valid, message = validate_graph(graph)
print(is_valid)
print(message)

True
Graph at is valid.


In [17]:
print(type(graph))

<class 'torch_geometric.data.data.Data'>


In [18]:
print("Graph is directed:", nx.is_directed(residue_graph))

Graph is directed: False


In [20]:
from functools import partial
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale
from graphein.protein.edges.distance import add_distance_threshold

config = ProteinGraphConfig(
        node_metadata_functions=[
            amino_acid_one_hot, meiler_embedding, expasy_protein_scale
        ],
        edge_construction_functions=[
            partial(add_distance_threshold, long_interaction_threshold=0)
        ]
    )

# Construct the graph using Graphein
graph = construct_graph(config=config, path=pdb_path)


Output()

In [27]:
for _, data in graph.nodes(data=True):
    print(data)
    break

{'chain_id': 'B', 'residue_name': 'GLY', 'residue_number': 1, 'atom_type': 'CA', 'element_symbol': 'C', 'coords': array([-0.294,  5.819, 75.824], dtype=float32), 'b_factor': 36.13999938964844, 'amino_acid_one_hot': array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'meiler': dim_1    0.00
dim_2    0.00
dim_3    0.00
dim_4    0.00
dim_5    6.07
dim_6    0.13
dim_7    0.15
Name: GLY, dtype: float64, 'expasy': pka_cooh_alpha              2.34
pka_nh3                     9.60
pka_rgroup                  7.00
isoelectric_points          6.06
molecularweight            75.00
                           ...  
antiparallelbeta_strand     0.56
parallelbeta_strand         0.79
a_a_composition             7.20
a_a_swiss_prot              7.07
relativemutability         49.00
Name: GLY, Length: 61, dtype: float64}


In [30]:
graph_types = ["onehot", "physchem", "expasy", "protbert", "prostt5"]

for graph_type in graph_types:# Generate node features based on graph type
    node_features = []
    for _, data in graph.nodes(data=True):
        if graph_type == "onehot":
            node_features.append(torch.tensor(data["amino_acid_one_hot"], dtype=torch.float))
        elif graph_type == "physchem":
            node_features.append(torch.tensor(data["meiler"].values, dtype=torch.float))
        elif graph_type == "expasy":
            node_features.append(torch.tensor(data["expasy"].values, dtype=torch.float))
        elif graph_type in ["protbert", "prostt5"] and embedding_fn:
            node_features = torch.tensor(embedding_fn(sequence_single_letter), dtype=torch.float)
        else:
            raise ValueError(f"Unknown graph type: {graph_type}")

    # Convert node features to a PyTorch tensor
    node_features = torch.stack(node_features)
    print(f'graph type: {graph_type}, node feature shape: {node_features.shape}')

graph type: onehot, node feature shape: torch.Size([160, 20])
graph type: physchem, node feature shape: torch.Size([160, 7])
graph type: expasy, node feature shape: torch.Size([160, 61])


NameError: name 'embedding_fn' is not defined

In [14]:
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from functools import partial
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot, meiler_embedding, expasy_protein_scale
from graphein.protein.edges.distance import add_distance_threshold

import torch
from torch_geometric.data import Data

config = ProteinGraphConfig(
        node_metadata_functions=[
            amino_acid_one_hot, meiler_embedding, expasy_protein_scale
        ],
        edge_construction_functions=[
            partial(add_distance_threshold, long_interaction_threshold=0)
        ]
    )

deafult_config = ProteinGraphConfig()
# Path to your local PDB file
pdb_path = "../../data/raw/2q0z.pdb"

# Construct the graph using Graphein
nx_graph = construct_graph(config=deafult_config, path=pdb_path)

# Map nodes to numeric indices
node_mapping = {node: idx for idx, node in enumerate(nx_graph.nodes())}
edge_list = [(node_mapping[u], node_mapping[v]) for u, v in nx_graph.edges()]
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
# # Assuming edge_index is [2, num_edges] and symmetric edges are required
# edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

# # Remove duplicate edges if they exist
# edge_index = torch.unique(edge_index, dim=1)

# Example Node Features (replace with your actual features)
num_nodes = len(node_mapping)  # Total number of nodes
node_features = torch.rand(num_nodes, 10)  # Example: 10 features per node

# Create PyTorch Geometric graph
graph = Data(x=node_features, edge_index=edge_index)

Output()

In [15]:
# Helper Functions
def validate_graph(graph, sequence_length=0):
    """
    Validates if a PyTorch Geometric graph is valid.
    """
    
    def is_edge_index_symmetric(edge_index):
        """
        Checks if `edge_index` contains non symmetric edges.
        """
        # Flip the edge index to get reversed edges
        edge_index_flipped = edge_index.flip(0)

        # Combine original and flipped edges
        combined_edges = torch.cat([edge_index, edge_index_flipped], dim=1)

        # Remove duplicate edges
        unique_edges = torch.unique(combined_edges, dim=1)

        # Check if all edges have their reverses
        return unique_edges.size(1) == edge_index.size(1)


    def has_duplicate_edges(edge_index):
        """
        Checks if `edge_index` contains duplicate edges.
        """
        edges = edge_index.t().tolist()
        unique_edges = set(map(tuple, edges))
        return len(edges) != len(unique_edges)

    try:
        if not isinstance(graph, Data):
            return False, "Graph is not a torch_geometric.data.data.Data object."

        if not hasattr(graph, "x") or graph.x is None or graph.x.size(0) == 0:
            return False, "Graph has no valid node features."

        if not hasattr(graph, "edge_index") or graph.edge_index is None or graph.edge_index.size(1) == 0:
            return False, "Graph has no valid edges."
        
        # # Validate node feature alignment with sequence length
        # if sequence_length != graph.x.size(0):
        #     return False, f"Mismatch between sequence length ({sequence_length}) and node features ({graph.x.size(0)})."

        if graph.edge_index.max() >= graph.x.size(0):
            return False, (
                f"Graph has invalid edges. "
                f"Max index in edge_index: {graph.edge_index.max()}, Num nodes: {graph.x.size(0)}."
            )

        if not is_edge_index_symmetric(graph.edge_index):
            return False, "Graph has a non-symmetric edge_index for an undirected graph."

        if has_duplicate_edges(graph.edge_index):
            return False, "Graph contains duplicate edges."

        # if graph.x.size(0) != sequence_length:
        #     return False, (
        #         f"Graph node features do not match sequence length. "
        #         f"Node features: {graph.x.size(0)}, Sequence length: {sequence_length}."
        #     )

        return True, "Graph is valid."

    except Exception as e:
        return False, f"Error validating graph: {str(e)}"

In [16]:
is_valid, message = validate_graph(graph)
print(is_valid)
print(message)

False
Graph has a non-symmetric edge_index for an undirected graph.


In [17]:
print(graph.edge_index.shape)

torch.Size([2, 273])


In [18]:
graph.edge_index[:,:10]

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])

In [19]:
edge_index_flipped = graph.edge_index.flip(0)
edge_index_flipped[:,:5]

tensor([[1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4]])

In [20]:
# Combine original and flipped edges
combined_edges = torch.cat([edge_index, edge_index_flipped], dim=1)
combined_edges[:,:5]

tensor([[0, 1, 2, 3, 4],
        [1, 2, 3, 4, 5]])

In [21]:
combined_edges[:,273:278]

tensor([[1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4]])

In [22]:
# Remove duplicate edges
unique_edges = torch.unique(combined_edges, dim=1)
unique_edges.shape

torch.Size([2, 546])

In [11]:
unique_edges[:,:10]

tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])

In [12]:
combined_edges.shape

torch.Size([2, 546])

In [13]:
num_nodes

282