In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from multiprocessing import Pool, cpu_count
from joblib import Parallel, delayed  # Better parallelism

# Path to the directory containing PDB files
pdb_dir = r"D:\P2Rank_GNN_Dataset\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

def validate_and_fix_mol(mol):
    """
    Removes atoms with invalid valences from the molecule.

    Args:
        mol (rdkit.Chem.Mol): The input molecule.

    Returns:
        rdkit.Chem.Mol: A molecule with problematic atoms removed.
    """
    to_remove = [atom.GetIdx() for atom in mol.GetAtoms()
                 if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum())]

    if to_remove:
        editable_mol = Chem.EditableMol(mol)
        for idx in sorted(to_remove, reverse=True):
            editable_mol.RemoveAtom(idx)
        return editable_mol.GetMol()
    return mol

def extract_sas_points(pdb_file):
    """
    Extracts 3D coordinates for SAS points from a PDB file.

    Args:
        pdb_file (str): Path to the PDB file.

    Returns:
        tuple: (pdb_file, torch.Tensor) or None if failed.
    """
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        if mol is None:
            raise RuntimeError(f"Failed to load PDB file: {pdb_file}")

        # Validate and fix problematic atoms
        mol = validate_and_fix_mol(mol)

        # Attempt sanitization)
        try:
            Chem.SanitizeMol(mol)
        except Exception as e:
            print(f"Sanitization failed for {pdb_file}: {e}")
            return None

        # Extract 3D coordinates directly (skip 2D computation)
        conf = mol.GetConformer()
        sas_points = [[pos.x, pos.y, pos.z] for pos in (conf.GetAtomPosition(atom.GetIdx()) for atom in mol.GetAtoms())]

        return pdb_file, torch.tensor(sas_points, dtype=torch.float)

    except Exception as e:
        print(f"Error processing {pdb_file}: {e}")
        return None

def process_pdb_files_parallel(pdb_files, num_workers=None):
    """
    Processes PDB files in parallel using joblib for better parallelism.

    Args:
        pdb_files (list): List of PDB file paths.
        num_workers (int, optional): Number of processes to use. Defaults to max available.

    Returns:
        list: List of tensors containing node features.
    """
    if num_workers is None:
        num_workers = min(32, cpu_count() - 2)  # Use a maximum of 32 workers to avoid Windows limitations

    print(f"Using {num_workers} worker processes.")

    # Use joblib for parallel processing
    results = Parallel(n_jobs=num_workers)(delayed(extract_sas_points)(file) for file in pdb_files)

    # Remove failed cases
    all_node_features = [res[1] for res in results if res is not None]

    return all_node_features

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Run parallel processing
all_node_features = process_pdb_files_parallel(pdb_files)

# Move tensors to GPU if available
all_node_features = [tensor.to(device) for tensor in all_node_features]

# Print summary
print(f"Processed {len(all_node_features)} out of {len(pdb_files)} files successfully.")

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

ImportError: numpy.core._multiarray_umath failed to import

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

RuntimeError: module compiled against API version 0x10 but this version of numpy is 0xf . Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem .

Using device: cuda
Using 32 worker processes.
Processed 251 out of 251 files successfully.


In [2]:
import numpy as np
import torch
from scipy.spatial.distance import cdist

def construct_edges(sas_points, distance_threshold=6.0):
    """
    Constructs edges based on spatial proximity between SAS points.

    Args:
        sas_points (np.ndarray): Array of 3D coordinates of SAS points.
        distance_threshold (float): Maximum distance to consider an edge.

    Returns:
        np.ndarray: Array of edge pairs.
    """
    if len(sas_points) == 0:
        return np.array([])
    
    # Compute pairwise distance matrix
    dist_matrix = cdist(sas_points, sas_points)
    
    # Identify edges where the distance is below the threshold
    edge_indices = np.where(dist_matrix < distance_threshold)
    
    # Filter out duplicate edges (i < j)
    edges = np.array([[i, j] for i, j in zip(*edge_indices) if i < j])
    
    return edges

# Example usage
# Assuming `all_node_features` is available from Code 1
all_edges = []

for sas_points in all_node_features:
    sas_np = sas_points.cpu().numpy()  # Convert tensor to NumPy array if needed
    edges = construct_edges(sas_np)
    all_edges.append(torch.tensor(edges.T, dtype=torch.long))  # Convert to tensor

# Print summary
print(f"Constructed edges for {len(all_edges)} graphs.")

Constructed edges for 251 graphs.


In [3]:
# Import required libraries
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from scipy.spatial.distance import euclidean
from torch_geometric.data import Data

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))


def validate_and_fix_mol(mol):
    """
    Removes atoms with invalid valences from the molecule.
    """
    if mol is None:
        return None
    
    to_remove = []
    for atom in mol.GetAtoms():
        if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum()):
            to_remove.append(atom.GetIdx())

    if to_remove:
        editable_mol = Chem.EditableMol(mol)
        for idx in sorted(to_remove, reverse=True):
            editable_mol.RemoveAtom(idx)
        return editable_mol.GetMol()
    return mol


def extract_sas_points(pdb_file):
    """
    Extracts 3D coordinates for SAS points from a PDB file.
    """
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        if mol is None:
            raise RuntimeError(f"Failed to load PDB file: {pdb_file}")

        # Validate and fix problematic atoms
        mol = validate_and_fix_mol(mol)
        if mol is None:
            raise RuntimeError(f"Invalid molecule structure in {pdb_file}")

        # Attempt sanitization
        Chem.SanitizeMol(mol)

        # Extract 3D coordinates
        conf = mol.GetConformer()
        sas_points = np.array([[pos.x, pos.y, pos.z] for pos in (conf.GetAtomPosition(atom.GetIdx()) for atom in mol.GetAtoms())])

        return sas_points if sas_points.size > 0 else None

    except Exception as e:
        print(f"Error processing PDB file {pdb_file}: {e}")
        return None


def construct_edges(sas_points, distance_threshold=6.0):
    """
    Constructs edges based on spatial proximity between SAS points within a threshold.
    """
    edges = []
    num_points = len(sas_points)
    for i in range(num_points):
        for j in range(i + 1, num_points):
            if euclidean(sas_points[i], sas_points[j]) < distance_threshold:
                edges.append([i, j])
    
    return np.array(edges) if edges else np.empty((0, 2))


# Process all PDB files
all_node_features = []
processed_files = 0

for pdb_file in pdb_files:
    sas_points = extract_sas_points(pdb_file)
    if sas_points is not None:
        # Convert SAS points to a torch tensor and store them
        node_features = torch.tensor(sas_points, dtype=torch.float)
        all_node_features.append(node_features)
        processed_files += 1

# Print summary of processed files
print(f"Processed {processed_files} out of {len(pdb_files)} files successfully.")

# Process the first valid PDB file as an example
if all_node_features:
    sas_points = all_node_features[0].numpy()  # Extract SAS points from the first molecule

    # Construct edges based on proximity
    edges = construct_edges(sas_points)

    # Convert edge indices to torch tensor
    edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)

    # Compute edge features (Euclidean distance)
    edge_features = torch.tensor([[euclidean(sas_points[i], sas_points[j])] for i, j in edges], dtype=torch.float).view(-1, 1) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

    # Example labels (dummy binding affinity or classification label)
    labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  # Dummy labels for nodes

    # Create the graph object using PyTorch Geometric
    data = Data(x=torch.tensor(sas_points, dtype=torch.float),
                edge_index=edge_index,
                edge_attr=edge_features,
                y=labels)

    # Print the graph details
    print(f'Number of nodes: {data.num_nodes}')
    print(f'Number of edges: {data.num_edges}')
else:
    print("No valid node features were extracted.")


Processed 251 out of 251 files successfully.
Number of nodes: 1011
Number of edges: 19479


In [4]:
from torch_geometric.data import Data
import torch
from scipy.spatial.distance import cdist

# Example: Define edge features (distances between connected nodes)
edge_features = torch.tensor(
    np.linalg.norm(sas_points[edges[:, 0]] - sas_points[edges[:, 1]], axis=1).reshape(-1, 1),
    dtype=torch.float
)

# Labels for the graph (example binding affinity or classification)
# Here, using a single scalar as the label for the entire graph
labels = torch.tensor([0.5], dtype=torch.float) 

# Create the graph object
data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features, y=labels)

# Print the graph details
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')


Number of nodes: 2026
Number of edges: 19479


In [5]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from scipy.spatial.distance import euclidean
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

def validate_and_fix_mol(mol):
    to_remove = [atom.GetIdx() for atom in mol.GetAtoms()
                 if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum())]

    if to_remove:
        editable_mol = Chem.EditableMol(mol)
        for idx in sorted(to_remove, reverse=True):
            editable_mol.RemoveAtom(idx)
        return editable_mol.GetMol()
    return mol

def extract_sas_points(pdb_file):
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        if mol is None:
            raise RuntimeError(f"Failed to load PDB file: {pdb_file}")

        # Validate and fix problematic atoms
        mol = validate_and_fix_mol(mol)
        Chem.SanitizeMol(mol)

        conf = mol.GetConformer()
        sas_points = np.array([[pos.x, pos.y, pos.z] for pos in (conf.GetAtomPosition(atom.GetIdx()) for atom in mol.GetAtoms())])

        return sas_points if sas_points.size > 0 else None
    except Exception as e:
        print(f"Error processing PDB file {pdb_file}: {e}")
        return None

def construct_edges(sas_points, distance_threshold=6.0):
    edges = []
    num_points = len(sas_points)
    for i in range(num_points):
        for j in range(i + 1, num_points):
            if euclidean(sas_points[i], sas_points[j]) < distance_threshold:
                edges.append([i, j])
    
    return np.array(edges) if edges else np.empty((0, 2))

class ProteinDataset(Dataset):
    def __init__(self, pdb_files, transform=None, pre_transform=None):
        super().__init__('.', transform, pre_transform)
        self.pdb_files = pdb_files

    def len(self):
        return len(self.pdb_files)

    def get(self, idx):
        pdb_file = self.pdb_files[idx]
        sas_points = extract_sas_points(pdb_file)
        if sas_points is None:
            return None

        edges = construct_edges(sas_points)
        edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)
        edge_features = torch.tensor([[euclidean(sas_points[i], sas_points[j])] for i, j in edges], dtype=torch.float).view(-1, 1) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

        node_features = torch.tensor(sas_points, dtype=torch.float)
        labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  

        data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features, y=labels)
        return data

# Create the dataset and DataLoader
protein_dataset = ProteinDataset(pdb_files)
data_loader = DataLoader(protein_dataset, batch_size=1, shuffle=True)

# Print summary of the dataset
print(f"Total number of graphs: {len(protein_dataset)}")

# Example of accessing a single graph data
for data in data_loader:
    print(data)
    break


Total number of graphs: 251
DataBatch(x=[2536, 3], edge_index=[2, 50553], edge_attr=[50553, 1], y=[2536], batch=[2536], ptr=[2])


In [6]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from scipy.spatial import KDTree
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from joblib import Parallel, delayed

# Detect CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

# ======================== Step 1: Validate and Fix Molecule ========================
def validate_and_fix_mol(mol):
    if mol is None:
        return None
    try:
        for atom in mol.GetAtoms():
            if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum()):
                return None  # Ignore molecules with incorrect valence
        Chem.SanitizeMol(mol)
        return mol
    except:
        return None

# ======================== Step 2: Extract SAS Points Efficiently ========================
def extract_sas_points(pdb_file):
    mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
    mol = validate_and_fix_mol(mol)
    if mol is None:
        return None

    conf = mol.GetConformer()
    sas_points = np.array([[conf.GetAtomPosition(atom.GetIdx()).x,
                            conf.GetAtomPosition(atom.GetIdx()).y,
                            conf.GetAtomPosition(atom.GetIdx()).z] for atom in mol.GetAtoms()])
    
    return sas_points if sas_points.size > 0 else None

# ======================== Step 3: Fast Edge Construction with KDTree ========================
def construct_edges(sas_points, distance_threshold=6.0):
    if len(sas_points) == 0:
        return np.empty((0, 2), dtype=np.int64)

    tree = KDTree(sas_points)
    pairs = tree.query_pairs(distance_threshold)
    edges = np.array(list(pairs))

    return edges if edges.size > 0 else np.empty((0, 2), dtype=np.int64)

# ======================== Step 5: Compute Local Graph Features Efficiently ========================
def compute_local_graph_features(edges, num_nodes):
    degree = np.zeros(num_nodes, dtype=np.float32)
    for i, j in edges:
        degree[i] += 1
        degree[j] += 1
    return torch.tensor(degree).view(-1, 1)  # Degree as node feature

# ======================== Step 6: Parallel Graph Construction ========================
def process_pdb_file(pdb_file):
    sas_points = extract_sas_points(pdb_file)
    if sas_points is None:
        return None

    edges = construct_edges(sas_points)
    edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)
    edge_features = torch.tensor([[np.linalg.norm(sas_points[i] - sas_points[j])] for i, j in edges], dtype=torch.float) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

    # Node features: SAS points + Local graph features
    node_features = torch.tensor(sas_points, dtype=torch.float)
    local_graph_features = compute_local_graph_features(edges, len(sas_points))
    
    # Combine node features (coordinates + local properties)
    combined_node_features = torch.cat([node_features, local_graph_features], dim=1)

    labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  

    data = Data(
        x=combined_node_features.to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_features.to(device),
        y=labels.to(device)
    )

    return data

# ======================== Step 7: Optimized Dataset Class ========================
class ProteinGraphDataset(InMemoryDataset):
    def __init__(self, root, pdb_files, transform=None, pre_transform=None):
        self.pdb_files = pdb_files
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()

    @property
    def processed_file_names(self):
        return ["protein_graphs.pt"]

    def process(self):
        # Parallel Processing using Joblib
        data_list = Parallel(n_jobs=8)(delayed(process_pdb_file)(pdb) for pdb in self.pdb_files)

        # Remove None values (failed PDB files)
        data_list = [d for d in data_list if d is not None]
        print(f"Processed {len(data_list)} graphs out of {len(self.pdb_files)} PDB files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        return data, slices

    def get(self, idx):
        return super().get(idx)  # Corrected to use the parent class method

# ======================== Step 8: Run Dataset Creation & DataLoader ========================
protein_graph_dataset = ProteinGraphDataset(root='protein_data', pdb_files=pdb_files)
dataset_loader = DataLoader(protein_graph_dataset, batch_size=1, shuffle=True)

print(f"Total number of graphs: {len(protein_graph_dataset)}")

# Example of accessing a single graph data
for data in dataset_loader:
    print(data)
    break
    

Total number of graphs: 85
DataBatch(x=[1758, 4], edge_index=[2, 34110], edge_attr=[34110, 1], y=[1758], batch=[1758], ptr=[2])


  self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()


In [7]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from scipy.spatial import KDTree
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from joblib import Parallel, delayed

# Detect CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

# ======================== Step 1: Validate and Fix Molecule ========================
def validate_and_fix_mol(mol):
    if mol is None:
        return None
    try:
        for atom in mol.GetAtoms():
            if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum()):
                return None  # Ignore molecules with incorrect valence
        Chem.SanitizeMol(mol)
        return mol
    except Exception as e:
        print(f"Error sanitizing molecule: {e}")
        return None

# ======================== Step 2: Extract SAS Points Efficiently ========================
def extract_sas_points(pdb_file):
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        mol = validate_and_fix_mol(mol)
        if mol is None:
            return None

        conf = mol.GetConformer()
        sas_points = np.array([[conf.GetAtomPosition(atom.GetIdx()).x,
                                conf.GetAtomPosition(atom.GetIdx()).y,
                                conf.GetAtomPosition(atom.GetIdx()).z] for atom in mol.GetAtoms()])
        
        return sas_points if sas_points.size > 0 else None
    except Exception as e:
        print(f"Error processing PDB file {pdb_file}: {e}")
        return None

# ======================== Step 3: Fast Edge Construction with KDTree ========================
def construct_edges(sas_points, distance_threshold=6.0):
    if len(sas_points) == 0:
        return np.empty((0, 2), dtype=np.int64)

    tree = KDTree(sas_points)
    pairs = tree.query_pairs(distance_threshold)
    edges = np.array(list(pairs))

    return edges if edges.size > 0 else np.empty((0, 2), dtype=np.int64)

# ======================== Step 5: Compute Local Graph Features Efficiently ========================
def compute_local_graph_features(edges, num_nodes):
    degree = np.zeros(num_nodes, dtype=np.float32)
    for i, j in edges:
        degree[i] += 1
        degree[j] += 1
    return torch.tensor(degree).view(-1, 1)  # Degree as node feature

# ======================== Step 6: Parallel Graph Construction ========================
def process_pdb_file(pdb_file):
    sas_points = extract_sas_points(pdb_file)
    if sas_points is None:
        return None

    edges = construct_edges(sas_points)
    edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)
    edge_features = torch.tensor([[np.linalg.norm(sas_points[i] - sas_points[j])] for i, j in edges], dtype=torch.float) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

    # Node features: SAS points + Local graph features
    node_features = torch.tensor(sas_points, dtype=torch.float)
    local_graph_features = compute_local_graph_features(edges, len(sas_points))
    
    # Combine node features (coordinates + local properties)
    combined_node_features = torch.cat([node_features, local_graph_features], dim=1)

    labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  

    data = Data(
        x=combined_node_features.to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_features.to(device),
        y=labels.to(device)
    )

    return data

# ======================== Step 7: Optimized Dataset Class ========================
class ProteinGraphDataset(InMemoryDataset):
    def __init__(self, root, pdb_files, transform=None, pre_transform=None):
        self.pdb_files = pdb_files
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()

    @property
    def processed_file_names(self):
        return ["protein_graphs.pt"]

    def process(self):
        # Parallel Processing using Joblib
        data_list = Parallel(n_jobs=8)(delayed(process_pdb_file)(pdb) for pdb in self.pdb_files)

        # Remove None values (failed PDB files)
        data_list = [d for d in data_list if d is not None]
        print(f"Processed {len(data_list)} graphs out of {len(self.pdb_files)} PDB files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        return data, slices

    def get(self, idx):
        return super().get(idx)  # Corrected to use the parent class method

# ======================== Step 8: Run Dataset Creation & DataLoader ========================
protein_graph_dataset = ProteinGraphDataset(root='protein_data', pdb_files=pdb_files)
dataset_loader = DataLoader(protein_graph_dataset, batch_size=1, shuffle=True)

print(f"Total number of graphs: {len(protein_graph_dataset)}")

# Example of accessing a single graph data
for data in dataset_loader:
    print(data)
    break

Total number of graphs: 85
DataBatch(x=[1242, 4], edge_index=[2, 23942], edge_attr=[23942, 1], y=[1242], batch=[1242], ptr=[2])


  self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()


In [8]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import torch
import os
import glob
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from joblib import Parallel, delayed
from scipy.spatial import KDTree

# Detect CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

def validate_and_fix_mol(mol):
    if mol is None:
        return None
    try:
        for atom in mol.GetAtoms():
            if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum()):
                return None  # Ignore molecules with incorrect valence
        Chem.SanitizeMol(mol)
        return mol
    except Exception as e:
        print(f"Error sanitizing molecule: {e}")
        return None

def extract_sas_points(pdb_file):
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        mol = validate_and_fix_mol(mol)
        if mol is None:
            return None

        # Compute Gasteiger partial charges
        Chem.rdPartialCharges.ComputeGasteigerCharges(mol)

        conf = mol.GetConformer()
        sas_points = np.array([[conf.GetAtomPosition(atom.GetIdx()).x,
                                conf.GetAtomPosition(atom.GetIdx()).y,
                                conf.GetAtomPosition(atom.GetIdx()).z] for atom in mol.GetAtoms()])
        
        return mol, sas_points if sas_points.size > 0 else None
    except Exception as e:
        print(f"Error processing PDB file {pdb_file}: {e}")
        return None, None

def construct_edges(sas_points, distance_threshold=6.0):
    if len(sas_points) == 0:
        return np.empty((0, 2), dtype=np.int64)

    tree = KDTree(sas_points)
    pairs = tree.query_pairs(distance_threshold)
    edges = np.array(list(pairs))

    return edges if edges.size > 0 else np.empty((0, 2), dtype=np.int64)

def compute_atomic_features(mol, sas_points, radius=6.0):
    atomic_features = []
    for sas_point in sas_points:
        nearby_atoms = []
        for atom in mol.GetAtoms():
            pos = mol.GetConformer().GetAtomPosition(atom.GetIdx())
            dist = np.linalg.norm(np.array([pos.x, pos.y, pos.z]) - sas_point)
            if dist <= radius:
                atomic_num = atom.GetAtomicNum()
                aromatic = int(atom.GetIsAromatic())
                hybridization = int(atom.GetHybridization())
                degree = atom.GetDegree()
                partial_charge = atom.GetDoubleProp('_GasteigerCharge')
                nearby_atoms.append([atomic_num, aromatic, hybridization, degree, partial_charge])
        
        if nearby_atoms:
            atomic_features.append(np.mean(nearby_atoms, axis=0))
        else:
            atomic_features.append(np.zeros(5))
    
    return torch.tensor(atomic_features, dtype=torch.float)

def process_pdb_file(pdb_file):
    mol, sas_points = extract_sas_points(pdb_file)
    if sas_points is None:
        return None

    edges = construct_edges(sas_points)
    edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)
    edge_features = torch.tensor([[np.linalg.norm(sas_points[i] - sas_points[j])] for i, j in edges], dtype=torch.float) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

    # Node features: Atomic features
    node_features = compute_atomic_features(mol, sas_points)
    
    labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  

    data = Data(
        x=node_features.to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_features.to(device),
        y=labels.to(device)
    )

    return data

class ProteinGraphDataset(InMemoryDataset):
    def __init__(self, root, pdb_files, transform=None, pre_transform=None):
        self.pdb_files = pdb_files
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()

    @property
    def processed_file_names(self):
        return ["protein_graphs.pt"]

    def process(self):
        data_list = Parallel(n_jobs=8)(delayed(process_pdb_file)(pdb) for pdb in self.pdb_files)
        data_list = [d for d in data_list if d is not None]
        print(f"Processed {len(data_list)} graphs out of {len(self.pdb_files)} PDB files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        return data, slices

    def get(self, idx):
        return super().get(idx)

protein_graph_dataset = ProteinGraphDataset(root='protein_data', pdb_files=pdb_files)
dataset_loader = DataLoader(protein_graph_dataset, batch_size=1, shuffle=True)

print(f"Total number of graphs: {len(protein_graph_dataset)}")

for data in dataset_loader:
    print(data)
    break


Total number of graphs: 85
DataBatch(x=[1782, 4], edge_index=[2, 35004], edge_attr=[35004, 1], y=[1782], batch=[1782], ptr=[2])


  self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()


In [9]:
import networkx as nx
import torch
from torch_geometric.data import DataLoader

def compute_graph_features(data):
    # Convert PyTorch Geometric Data to NetworkX graph
    edge_index = data.edge_index.cpu().numpy()
    G = nx.Graph()
    G.add_edges_from(edge_index.T)

    # Compute node degrees
    degrees = dict(G.degree())

    # Compute clustering coefficients
    clustering_coeffs = nx.clustering(G)

    return {
        "degree": degrees,
        "clustering_coefficient": clustering_coeffs
    }

# Example usage with DataLoader
for data in dataset_loader:
    features = compute_graph_features(data)
    print(f"Degrees: {features['degree']}")
    print(f"Clustering Coefficients: {features['clustering_coefficient']}")
    break


Degrees: {266: 53, 268: 50, 400: 38, 1022: 52, 928: 56, 1141: 55, 298: 42, 745: 47, 1447: 29, 1706: 57, 490: 49, 915: 43, 1620: 33, 1634: 39, 1257: 48, 1582: 46, 4: 38, 1961: 48, 1220: 47, 1700: 46, 1256: 50, 1357: 42, 129: 39, 155: 24, 561: 46, 566: 28, 1421: 39, 1424: 37, 44: 34, 47: 24, 904: 52, 905: 54, 1115: 41, 1125: 36, 1975: 47, 1983: 52, 619: 48, 921: 62, 273: 44, 284: 45, 1138: 56, 1486: 57, 1133: 44, 1142: 47, 1458: 30, 1464: 29, 424: 38, 453: 29, 1229: 50, 1604: 56, 1671: 39, 1757: 42, 1344: 49, 1362: 45, 1476: 12, 1481: 49, 1687: 38, 1701: 46, 99: 36, 104: 54, 442: 22, 443: 16, 1488: 48, 1491: 49, 254: 23, 392: 40, 419: 54, 459: 49, 1339: 48, 1368: 24, 971: 33, 972: 43, 671: 48, 680: 53, 1183: 52, 1794: 33, 1375: 42, 1964: 43, 1543: 34, 1548: 31, 166: 29, 171: 33, 245: 52, 464: 46, 1026: 31, 1029: 31, 1886: 46, 1887: 45, 1977: 52, 2005: 45, 449: 34, 614: 51, 2024: 34, 509: 57, 510: 51, 593: 47, 1147: 43, 1328: 47, 1367: 27, 1352: 47, 1953: 45, 1309: 46, 1317: 53, 1231: 52,

In [11]:
import torch
import torch_geometric
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from torch_geometric.utils import from_networkx
import networkx as nx
import numpy as np
import os
import glob
from scipy.spatial import KDTree
from rdkit import Chem
from rdkit.Chem import AllChem
from joblib import Parallel, delayed

# Detect CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the directory containing PDB files
pdb_dir = "D:\\P2Rank_GNN_Dataset\\chen11"

# Get a list of all PDB files in the directory
pdb_files = glob.glob(os.path.join(pdb_dir, "*.pdb"))

def validate_and_fix_mol(mol):
    if mol is None:
        return None
    try:
        for atom in mol.GetAtoms():
            if atom.GetExplicitValence() > Chem.GetPeriodicTable().GetDefaultValence(atom.GetAtomicNum()):
                return None  # Ignore molecules with incorrect valence
        Chem.SanitizeMol(mol)
        return mol
    except Exception as e:
        print(f"Error sanitizing molecule: {e}")
        return None

def extract_sas_points(pdb_file):
    try:
        mol = Chem.MolFromPDBFile(pdb_file, removeHs=False, sanitize=False)
        mol = validate_and_fix_mol(mol)
        if mol is None:
            return None

        conf = mol.GetConformer()
        sas_points = np.array([[conf.GetAtomPosition(atom.GetIdx()).x,
                                conf.GetAtomPosition(atom.GetIdx()).y,
                                conf.GetAtomPosition(atom.GetIdx()).z] for atom in mol.GetAtoms()])
        
        return sas_points if sas_points.size > 0 else None
    except Exception as e:
        print(f"Error processing PDB file {pdb_file}: {e}")
        return None

def construct_edges(sas_points, distance_threshold=6.0):
    if len(sas_points) == 0:
        return np.empty((0, 2), dtype=np.int64)

    tree = KDTree(sas_points)
    pairs = tree.query_pairs(distance_threshold)
    edges = np.array(list(pairs))

    return edges if edges.size > 0 else np.empty((0, 2), dtype=np.int64)

def compute_local_graph_features(edges, num_nodes):
    degree = np.zeros(num_nodes, dtype=np.float32)
    for i, j in edges:
        degree[i] += 1
        degree[j] += 1
    return torch.tensor(degree).view(-1, 1)  # Degree as node feature

def process_pdb_file(pdb_file):
    sas_points = extract_sas_points(pdb_file)
    if sas_points is None:
        return None

    edges = construct_edges(sas_points)
    edge_index = torch.tensor(edges.T, dtype=torch.long) if edges.size > 0 else torch.empty((2, 0), dtype=torch.long)
    edge_features = torch.tensor([[np.linalg.norm(sas_points[i] - sas_points[j])] for i, j in edges], dtype=torch.float) if edges.size > 0 else torch.empty((0, 1), dtype=torch.float)

    node_features = torch.tensor(sas_points, dtype=torch.float)
    local_graph_features = compute_local_graph_features(edges, len(sas_points))
    combined_node_features = torch.cat([node_features, local_graph_features], dim=1)

    labels = torch.full((len(sas_points),), 0.5, dtype=torch.float)  

    data = Data(
        x=combined_node_features.to(device),
        edge_index=edge_index.to(device),
        edge_attr=edge_features.to(device),
        y=labels.to(device)
    )

    return data

class ProteinGraphDataset(InMemoryDataset):
    def __init__(self, root, pdb_files, transform=None, pre_transform=None):
        self.pdb_files = pdb_files
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()

    @property
    def processed_file_names(self):
        return ["protein_graphs.pt"]

    def process(self):
        data_list = Parallel(n_jobs=8)(delayed(process_pdb_file)(pdb) for pdb in self.pdb_files)
        data_list = [d for d in data_list if d is not None]
        print(f"Processed {len(data_list)} graphs out of {len(self.pdb_files)} PDB files.")

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        return data, slices

    def get(self, idx):
        return super().get(idx)

protein_graph_dataset = ProteinGraphDataset(root='protein_data', pdb_files=pdb_files)
dataset_loader = DataLoader(protein_graph_dataset, batch_size=1, shuffle=True)

print(f"Total number of graphs: {len(protein_graph_dataset)}")

for data in dataset_loader:
    print(data)
    break


Total number of graphs: 85
DataBatch(x=[2216, 4], edge_index=[2, 42974], edge_attr=[42974, 1], y=[2216], batch=[2216], ptr=[2])


  self.data, self.slices = torch.load(self.processed_paths[0]) if os.path.exists(self.processed_paths[0]) else self.process()


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
import torch.optim as optim

# Define the GCN Encoder
class GCNEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, latent_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv_mu = GCNConv(hidden_dim, latent_dim)  # Mean of latent space
        self.conv_logvar = GCNConv(hidden_dim, latent_dim)  # Log variance

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(x, edge_index)
        logvar = self.conv_logvar(x, edge_index)
        return mu, logvar

# Define the Variational Autoencoder (VAE)
class VAE(nn.Module):
    def __init__(self, in_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = GCNEncoder(in_dim, hidden_dim, latent_dim)
        self.decoder = GCNConv(latent_dim, in_dim)  # Reconstruct original input

    def reparameterize(self, mu, logvar):
        std = torch.exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, edge_index):
        mu, logvar = self.encoder(x, edge_index)  # Encode input
        logvar = torch.clamp(logvar, min=-10, max=10)  # Stabilize variance
        z = self.reparameterize(mu, logvar)  # Sample latent vector
        recon_x = self.decoder(z, edge_index)  # Reconstruct input
        return recon_x, mu, logvar

# VAE Loss Function
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')  # Reconstruction Loss
    kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())  # KL Divergence
    return recon_loss + kl_div

# Training Function
def train_vae(model, dataloader, optimizer, device, epochs=50):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for data in dataloader:
            data = data.to(device)
            optimizer.zero_grad()
            recon_x, mu, logvar = model(data.x, data.edge_index)
            loss = vae_loss(recon_x, data.x, mu, logvar)
            
            if torch.isnan(loss):  # Check for NaN loss
                print(f"NaN encountered at epoch {epoch+1}, skipping update...")
                continue  # Skip update to prevent corrupting weights
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

# Ensure dataset is valid (avoid NaNs)
def check_dataset(dataset):
    for i, data in enumerate(dataset):
        if torch.isnan(data.x).any():
            print(f"NaN detected in node features of graph {i}")
        if torch.isnan(data.edge_index).any():
            print(f"NaN detected in edge index of graph {i}")

# Initialize Model and Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check dataset before training
check_dataset(protein_graph_dataset)

model = VAE(in_dim=protein_graph_dataset[0].x.shape[1], hidden_dim=64, latent_dim=32).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)  # Reduced LR for stability

# Use the processed protein dataset
dataloader = DataLoader(protein_graph_dataset, batch_size=16, shuffle=True)

# Train the Model
train_vae(model, dataloader, optimizer, device, epochs=50)


Epoch 1/50, Loss: 2898.3482
Epoch 2/50, Loss: 2305.1294
Epoch 3/50, Loss: 1851.2688
Epoch 4/50, Loss: 1500.8197
Epoch 5/50, Loss: 1295.1101
Epoch 6/50, Loss: 997.6453
Epoch 7/50, Loss: 838.5537
Epoch 8/50, Loss: 750.2469
Epoch 9/50, Loss: 626.2530
Epoch 10/50, Loss: 562.7170
Epoch 11/50, Loss: 502.6624
Epoch 12/50, Loss: 497.6691
Epoch 13/50, Loss: 400.3780
Epoch 14/50, Loss: 374.6241
Epoch 15/50, Loss: 357.9088
Epoch 16/50, Loss: 331.5323
Epoch 17/50, Loss: 315.0074
Epoch 18/50, Loss: 326.9326
Epoch 19/50, Loss: 292.6086
Epoch 20/50, Loss: 274.7918
Epoch 21/50, Loss: 269.6394
Epoch 22/50, Loss: 259.1256
Epoch 23/50, Loss: 245.0397
Epoch 24/50, Loss: 231.5599
Epoch 25/50, Loss: 236.5069
Epoch 26/50, Loss: 231.4348
Epoch 27/50, Loss: 235.1608
Epoch 28/50, Loss: 214.2791
Epoch 29/50, Loss: 218.0996
Epoch 30/50, Loss: 210.0375
Epoch 31/50, Loss: 208.6964
Epoch 32/50, Loss: 199.3753
Epoch 33/50, Loss: 202.2251
Epoch 34/50, Loss: 208.1266
Epoch 35/50, Loss: 196.6840
Epoch 36/50, Loss: 187.9

In [21]:
import torch
import random
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import DataLoader  # Use PyG DataLoader

# Use the processed protein dataset
dataloader = DataLoader(protein_graph_dataset, batch_size=16, shuffle=True)

def get_link_labels(edge_index, num_nodes):
    """
    Generate positive and negative edges for link prediction.
    """
    # Positive edges (existing edges)
    pos_edge_index = edge_index

    # Negative edges (randomly sampled non-existent edges)
    neg_edge_index = negative_sampling(
        edge_index=edge_index, num_nodes=num_nodes, num_neg_samples=edge_index.size(1)
    )

    # Labels: 1 for positive edges, 0 for negative edges
    edge_labels = torch.cat([
        torch.ones(pos_edge_index.size(1)),  # Positive labels
        torch.zeros(neg_edge_index.size(1))  # Negative labels
    ]).to(torch.float)

    # Combine positive and negative edges
    combined_edges = torch.cat([pos_edge_index, neg_edge_index], dim=1)

    return combined_edges, edge_labels

# Process the protein dataset batch-wise
for batch in dataloader:
    batch = batch.to('cuda' if torch.cuda.is_available() else 'cpu')  # Move to GPU if available
    combined_edges, edge_labels = get_link_labels(batch.edge_index, batch.num_nodes)


In [22]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class LinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, 1)  # Concatenate node pairs

    def forward(self, x, edge_index, combined_edges):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Extract node embeddings from batched graph
        src = x[combined_edges[0]]  # Source nodes
        dst = x[combined_edges[1]]  # Destination nodes
        edge_features = torch.cat([src, dst], dim=1)  # Concatenate features

        return torch.sigmoid(self.fc(edge_features)).squeeze()


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader

# Function to get positive & negative edges for training
def get_link_labels(edge_index, num_nodes):
    device = edge_index.device
    num_edges = edge_index.size(1)

    # Positive edges (existing edges)
    pos_edges = edge_index

    # Generate negative edges (random pairs that are not connected)
    neg_edges = torch.randint(0, num_nodes, (2, num_edges), device=device)
    
    # Concatenate positive and negative edges
    combined_edges = torch.cat([pos_edges, neg_edges], dim=1)
    
    # Labels: 1 for positive edges, 0 for negative edges
    edge_labels = torch.cat([torch.ones(num_edges), torch.zeros(num_edges)], dim=0).to(device)

    return combined_edges, edge_labels

# Link Predictor Model
class LinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, 1)  # Concat source & target embeddings

    def forward(self, x, edge_index, combined_edges):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Extract node embeddings
        src = x[combined_edges[0]]
        dst = x[combined_edges[1]]
        edge_features = torch.cat([src, dst], dim=1)

        return torch.sigmoid(self.fc(edge_features)).squeeze()

# DataLoader for batched training
dataloader = DataLoader(protein_graph_dataset, batch_size=16, shuffle=True)

# Dynamically infer input feature dimension
sample_data = next(iter(dataloader))
in_dim = sample_data.x.shape[1]  # Automatically get feature dimension

# Initialize model, optimizer, and loss function
link_predictor = LinkPredictor(in_dim=in_dim, hidden_dim=32)
optimizer = torch.optim.Adam(link_predictor.parameters(), lr=0.01)
loss_fn = nn.BCELoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
link_predictor = link_predictor.to(device)

# Training loop
for epoch in range(100):
    link_predictor.train()
    total_loss = 0

    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        combined_edges, edge_labels = get_link_labels(batch.edge_index, batch.num_nodes)
        edge_labels = edge_labels.to(device)

        pred = link_predictor(batch.x, batch.edge_index, combined_edges)
        loss = loss_fn(pred, edge_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    if epoch % 10 == 0:
        acc = ((pred > 0.5) == edge_labels).float().mean().item()
        print(f'Epoch {epoch}, Loss: {total_loss:.4f}, Accuracy: {acc:.4f}')


Epoch 0, Loss: 9.5393, Accuracy: 0.4780
Epoch 10, Loss: 4.1608, Accuracy: 0.4957
Epoch 20, Loss: 4.1538, Accuracy: 0.5057
Epoch 30, Loss: 4.1465, Accuracy: 0.5250
Epoch 40, Loss: 4.1307, Accuracy: 0.5537
Epoch 50, Loss: 4.1049, Accuracy: 0.5630
Epoch 60, Loss: 4.0816, Accuracy: 0.5727
Epoch 70, Loss: 4.0584, Accuracy: 0.5741
Epoch 80, Loss: 4.0557, Accuracy: 0.5839
Epoch 90, Loss: 4.0481, Accuracy: 0.5788


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import StepLR

# Improved negative sampling strategy
def get_link_labels(edge_index, num_nodes):
    device = edge_index.device
    num_edges = edge_index.size(1)

    pos_edges = edge_index

    # Hard negative sampling: sample from nodes that are closer in graph structure
    neg_edges = torch.randint(0, num_nodes, (2, num_edges), device=device)
    
    # Avoid sampling existing edges
    mask = torch.isin(neg_edges, pos_edges).any(dim=0)
    while mask.any():
        neg_edges[:, mask] = torch.randint(0, num_nodes, (2, mask.sum()), device=device)
        mask = torch.isin(neg_edges, pos_edges).any(dim=0)

    # Concatenate positive & negative edges
    combined_edges = torch.cat([pos_edges, neg_edges], dim=1)
    edge_labels = torch.cat([torch.ones(num_edges), torch.zeros(num_edges)], dim=0).to(device)

    return combined_edges, edge_labels

# Enhanced Link Predictor Model
class LinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, hidden_dim)  # Added one more layer
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, combined_edges):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = self.dropout(x)

        # Extract node embeddings
        src = x[combined_edges[0]]
        dst = x[combined_edges[1]]
        edge_features = torch.cat([src, dst], dim=1)

        return self.fc(edge_features).squeeze()  

# DataLoader
dataloader = DataLoader(protein_graph_dataset, batch_size=16, shuffle=True)

# Get input feature dimension dynamically
sample_data = next(iter(dataloader))
in_dim = sample_data.x.shape[1]  

# Model, Optimizer, and Loss
link_predictor = LinkPredictor(in_dim=in_dim, hidden_dim=64)  # Increased hidden size
optimizer = torch.optim.AdamW(link_predictor.parameters(), lr=0.005, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=20, gamma=0.8)  # Learning rate decay
loss_fn = nn.BCEWithLogitsLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
link_predictor = link_predictor.to(device)

# Training loop
for epoch in range(100):
    link_predictor.train()
    total_loss = 0

    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()

        combined_edges, edge_labels = get_link_labels(batch.edge_index, batch.num_nodes)
        edge_labels = edge_labels.to(device)

        pred = link_predictor(batch.x, batch.edge_index, combined_edges)
        loss = loss_fn(pred, edge_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    scheduler.step()

    # Compute accuracy
    with torch.no_grad():
        acc = ((torch.sigmoid(pred) > 0.5) == edge_labels).float().mean().item()

    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss:.4f}, Accuracy: {acc:.4f}')
