In [26]:
import torch
import random

import pandas as pd

from rdkit import Chem
from rdkit.Chem import AllChem

from pathlib import Path

from torch_geometric.data import Data

from tqdm import tqdm

In [27]:
import sys
import os

# Make the parent of the current working directory importable so that the
# project-local `Dataset_Creation` package can be resolved from this notebook.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
# Only add the entry once: re-running this cell would otherwise keep
# appending duplicate paths to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
from Dataset_Creation.script_pairs_creation_with_torch import smiles_to_torch_geometric, node_encoder, edge_encoder

In [28]:
# CSV file located one directory above the current working directory.
csv_path = Path('..').joinpath('rndm_zinc_drugs_clean_3.csv')

# Load the full table; the preprocessing cell reads its `smiles` column.
zinc_df = pd.read_csv(filepath_or_buffer=csv_path)

In [29]:
def process_graph(smiles):
    """
    Convert a SMILES string into a graph whose node and edge features are encoded.

    Args:
    smiles: SMILES string passed to `smiles_to_torch_geometric`.

    Returns:
    encoded_data (torch_geometric.data.Data): Graph holding the encoded `x` and
    `edge_attr` together with the unmodified `edge_index` of the raw graph.
    """
    raw_graph = smiles_to_torch_geometric(smiles)

    # Run node_encoder on each row of the raw node features (one atom at a
    # time) and stack the results back into a single tensor.
    encoded_x = torch.stack(list(map(node_encoder, raw_graph.x)))

    # Same for the bonds: edge_encoder is applied to one edge row at a time.
    encoded_edge_attr = torch.stack(list(map(edge_encoder, raw_graph.edge_attr)))

    # The connectivity is reused as-is; only the features are replaced.
    return Data(
        x=encoded_x,
        edge_attr=encoded_edge_attr,
        edge_index=raw_graph.edge_index,
    )

    

In [30]:
# Encode every molecule listed in zinc_df.
# `itertuples()` is a generator without a length, so tqdm is told the total
# explicitly; otherwise the bar can only show a running count (no percentage
# or ETA).
preprocessed_graph = []
for row in tqdm(zinc_df.itertuples(), total=len(zinc_df)):
    smiles = row.smiles
    data = process_graph(smiles)
    preprocessed_graph.append(data)

# Persist the list of encoded graphs in the current working directory.
torch.save(preprocessed_graph, 'preprocessed_graph.pt')

249455it [08:42, 477.78it/s]


In [32]:
# Keep the first encoded graph under a short name. Note that this rebinds
# `data`, which the preprocessing loop also used as its loop variable.
data = preprocessed_graph[0]

In [33]:
def get_subgraph(data, indices, id_map):
    """
    Get a subgraph of a torch_geometric graph molecule based on given indices.

    Args:
    data (torch_geometric.data.Data): A PyTorch Geometric Data object representing the molecule.
    indices (list or torch.Tensor): List of node indices to extract as subgraph.
    id_map (dict): Maps every original node index in `indices` to the index that
        node gets in the subgraph. The values are used as row positions in a
        feature matrix with one row per selected node, so they must lie in
        `0 .. len(indices) - 1`.

    Returns:
    subgraph_data (torch_geometric.data.Data): Subgraph of the torch_geometric graph molecule.
        `edge_index` always has shape (2, num_kept_edges) and dtype long, including
        when no edge connects the selected nodes.

    Raises:
    KeyError: If an entry of `indices` is missing from `id_map`.
    """
    if not isinstance(indices, torch.Tensor):
        indices = torch.tensor(indices, dtype=torch.long)

    # Restrict the old -> new index mapping to the requested nodes
    index_map = {old_index: id_map[old_index] for old_index in indices.tolist()}

    # Extract node features, placing each node at its new position
    subgraph_x = torch.zeros(len(index_map), data.x.size(1))
    for old_index, new_index in index_map.items():
        subgraph_x[new_index] = data.x[old_index]

    # Keep only the edges whose two endpoints are both selected
    edge_pairs = data.edge_index.t().tolist()
    keep = [src in index_map and tgt in index_map for src, tgt in edge_pairs]
    mask = torch.tensor(keep, dtype=torch.bool)

    # Relabel the kept edges according to the new node indices.
    # The explicit dtype and reshape matter when nothing is kept:
    # torch.tensor([]) alone is a float tensor of shape (0,), which is not a
    # valid edge_index; reshape(-1, 2) turns it into (0, 2) -> (2, 0) long.
    relabeled = [
        [index_map[src], index_map[tgt]]
        for (src, tgt), kept in zip(edge_pairs, keep)
        if kept
    ]
    subgraph_edge_index = torch.tensor(relabeled, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # Extract corresponding edge attributes (same order as the kept edges)
    subgraph_edge_attr = data.edge_attr[mask]

    # Create a new torch_geometric data object
    subgraph_data = Data(x=subgraph_x, edge_index=subgraph_edge_index, edge_attr=subgraph_edge_attr)

    return subgraph_data


In [34]:
def get_subgraph_with_terminal_nodes(data, num_atoms):
    """
    Get a subgraph of a torch_geometric graph molecule based on specific rules.

    A start atom is drawn at random and the graph is explored breadth-first
    (neighbors are visited in shuffled order) until `num_atoms` atoms have been
    visited or the start atom's connected component is exhausted. Visited atoms
    are numbered in visiting order.

    Args:
    data (torch_geometric.data.Data): A PyTorch Geometric Data object representing the molecule.
    num_atoms (int): Desired number of atoms in the subgraph.

    Returns:
    subgraph_data (torch_geometric.data.Data): Subgraph of the torch_geometric graph molecule.
    terminal_node_info (tuple): `(oldest_non_completed, external_neighbors)`.
        `oldest_non_completed` is, among the visited neighbors of the last
        visited atom, the one that was visited first; if the last visited atom
        has no visited neighbor (e.g. `num_atoms == 1`), it is the last visited
        atom itself. `external_neighbors` has one entry per not-yet-visited
        neighbor of that atom:
        `(neighbor, data.x[neighbor], edge_attr_to_neighbor, other_edges)`,
        where `other_edges` lists `(visited_atom, edge_attr)` for every other
        visited atom bonded to that neighbor. All atom indices in this tuple
        refer to the ORIGINAL graph; translate them with `id_map` if needed.
    id_map (dict): Maps original atom index -> index in `subgraph_data`.

    Raises:
    ValueError: If `num_atoms` is not between 1 and the number of nodes.
    """
    # Local import keeps this cell self-contained; only this function needs it.
    from collections import deque

    num_nodes = len(data.x)
    if num_atoms < 1 or num_atoms > num_nodes:
        raise ValueError("num_atoms must be between 1 and the number of nodes in the graph.")

    # Randomly select an atom
    start_atom = random.choice(range(num_nodes))

    # `queue` holds discovered-but-unvisited atoms; `discovered` mirrors
    # "visited or queued" so membership checks do not scan the queue.
    queue = deque([start_atom])
    discovered = {start_atom}
    visited = set()

    id_map = {}

    # Breadth-first search
    while queue:
        current = queue.popleft()
        visited.add(current)

        # Atoms are numbered in the order they are visited
        id_map[current] = len(id_map)

        # Add neighbors to the queue
        neighbors = data.edge_index[:, data.edge_index[0] == current][1].tolist()
        random.shuffle(neighbors)
        for neighbor in neighbors:
            if neighbor not in discovered:
                discovered.add(neighbor)
                queue.append(neighbor)

        # Stop if we've reached the desired number of atoms
        if len(visited) == num_atoms:
            break

    # Get the subgraph with the selected atoms
    subgraph_indices = torch.tensor(list(visited), dtype=torch.long)
    subgraph_data = get_subgraph(data, subgraph_indices, id_map)

    # `neighbors` / `current` still refer to the last visited atom here.
    visited_neighbors = [i for i in neighbors if i in visited]
    if visited_neighbors:
        oldest_non_completed = min(visited_neighbors, key=lambda x: id_map[x])
    else:
        # No visited neighbor exists (single-atom subgraph or isolated start
        # atom): min() on an empty list would raise, so fall back to the last
        # visited atom itself.
        oldest_non_completed = current

    def _edge_attr(src, tgt):
        # Attribute of the (first) directed edge src -> tgt in the original graph
        selector = (data.edge_index[0] == src) & (data.edge_index[1] == tgt)
        return data.edge_attr[selector][0]

    external_neighbors = []
    neighbors_oldest_non_completed = data.edge_index[:, data.edge_index[0] == oldest_non_completed][1].tolist()
    for neighbor in neighbors_oldest_non_completed:
        if neighbor in visited:
            continue

        edge_data = _edge_attr(oldest_non_completed, neighbor)

        # Bonds from this external atom back into the subgraph, other than the
        # one to the terminal atom itself
        external_neighbors_edges = []
        for neighbor2 in data.edge_index[:, data.edge_index[0] == neighbor][1].tolist():
            if neighbor2 in visited and neighbor2 != oldest_non_completed:
                external_neighbors_edges.append((neighbor2, _edge_attr(neighbor, neighbor2)))

        external_neighbors.append((neighbor, data.x[neighbor], edge_data, external_neighbors_edges))

    terminal_node_info = (oldest_non_completed, external_neighbors)

    return subgraph_data, terminal_node_info, id_map