<a href="https://colab.research.google.com/github/Sfonzie98/Dissertation/blob/main/graph_converter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Graph Converter
> 

## Install and Import Useful Libraries

In [None]:
!pip install rdkit-pypi
!pip install deepchem
!pip install -q dgl
!pip install networkx
!pip3 install torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
# Useful libraries to load and process the data
import pandas as pd
import numpy as np
import os
import os.path as osp
from tqdm import tqdm

# Useful libraries to create and featurize the graph
import torch
import torch_geometric
from torch_geometric import data
from torch_geometric.data import Data, Dataset, InMemoryDataset
import networkx as nx

# Useful libraries to visualize chemical structures and calculate chemical properties 
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
import deepchem as dc

## Create and Featurize the Graph

In [None]:
class MoleculeDataset(Dataset):
  """
  On-disk graph dataset built from a CSV file of molecules.

  Every row of the raw CSV (columns "Smiles" and "Label") is parsed with
  RDKit and turned into one `Data` object holding node features, edge
  features, connectivity, the label and the original SMILES string.
  Each graph is saved as its own .pt file in processed_dir and is loaded
  from disk on demand by `get`.
  """

  # Number of columns produced by get_node_features / get_edge_features.
  # Used to keep the feature matrices 2-D even when a molecule has no
  # atoms or no bonds.
  _NODE_FEAT_DIM = 14
  _EDGE_FEAT_DIM = 4

  def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
    """
    root = Where the dataset should be stored. This folder is split
    into raw_dir (downloaded dataset) and processed_dir (processed data).
    filename = Name of the CSV file expected inside raw_dir.
    test = If True, processed files are named data_test_<i>.pt instead
    of data_<i>.pt.
    transform / pre_transform = Forwarded unchanged to the base class.
    """
    self.test = test
    self.filename = filename
    # Cache for the raw CSV. It has to exist before the base-class
    # constructor runs, because that constructor queries
    # processed_file_names, which needs the number of rows.
    self.data = None
    super().__init__(root, transform, pre_transform)

  @property
  def raw_file_names(self):
    """
    If this file exists in raw_dir, the download is not triggered.
    (The download func. is not implemented here)
    """
    return self.filename

  @property
  def processed_file_names(self):
    """
    One file name per CSV row, exactly as written by `process`.
    If all of these files are found in processed_dir, processing is skipped.
    """
    num_rows = self._raw_frame().shape[0]
    return [self._processed_name(i) for i in range(num_rows)]

  def download(self):
    """Nothing to download: the raw CSV must already be in raw_dir."""
    pass

  def _raw_frame(self):
    """Return the raw CSV as a DataFrame, reading it from disk only once."""
    if getattr(self, "data", None) is None:
      self.data = pd.read_csv(self.raw_paths[0])
    return self.data

  def _processed_name(self, idx):
    """Return the processed file name for the graph at position idx."""
    prefix = "data_test" if self.test else "data"
    return f"{prefix}_{idx}.pt"

  def process(self):
    """
    Convert every CSV row into a `Data` object and save it to processed_dir.

    Raises ValueError if RDKit cannot parse a SMILES string, instead of
    failing later with an AttributeError on a None molecule.
    """
    frame = self._raw_frame()
    rows = zip(frame["Smiles"], frame["Label"])
    # Positional index (0..n-1) so that file names line up with the
    # indices requested through `get`.
    for index, (smiles, label) in enumerate(tqdm(rows, total=frame.shape[0])):
      mol_obj = Chem.MolFromSmiles(smiles)
      if mol_obj is None:
        raise ValueError(
            f"Row {index}: RDKit could not parse SMILES {smiles!r}")
      # Get adjacency info
      edge_index = self.get_adjacency_info(mol_obj)
      # Get edge features
      edge_feats = self.get_edge_features(mol_obj)
      # Get node features
      node_feats = self.get_node_features(mol_obj)
      # Get labels info
      classes = self.get_labels(label)

      # Create data object
      data = Data(edge_index=edge_index,
                  edge_attr=edge_feats,
                  x=node_feats,
                  y=classes,
                  smiles=smiles)
      torch.save(data,
                 os.path.join(self.processed_dir,
                              self._processed_name(index)))

  def get_node_features(self, mol):
    """
    This will return a matrix / 2d array of the shape
    [Number of Nodes, Node Feature size]
    """
    all_node_feats = []

    for atom in mol.GetAtoms():
      node_feats = [
          # Feature 1: Atomic number
          atom.GetAtomicNum(),
          # Feature 2: Atomic mass
          atom.GetMass(),
          # Feature 3: Atom degree
          atom.GetDegree(),
          # Feature 4: The degree of an atom including Hs
          atom.GetTotalDegree(),
          # Feature 5: Explicit valence
          atom.GetExplicitValence(),
          # Feature 6: Implicit valence
          atom.GetImplicitValence(),
          # Feature 7: Formal charge
          atom.GetFormalCharge(),
          # Feature 8: Hybridization
          atom.GetHybridization(),
          # Feature 9: Aromaticity
          atom.GetIsAromatic(),
          # Feature 10: Total Num Hs
          atom.GetTotalNumHs(),
          # Feature 11: Radical Electrons
          atom.GetNumRadicalElectrons(),
          # Feature 12: In Ring
          atom.IsInRing(),
          # Feature 13: Atom is chiral center
          # NOTE(review): '_ChiralityPossible' appears to be set only when
          # stereochemistry is assigned with flagPossibleStereoCenters=True;
          # confirm this feature is not constantly False.
          atom.HasProp('_ChiralityPossible'),
          # Feature 14: Chirality
          atom.GetChiralTag(),
      ]

      # Append node features to matrix
      all_node_feats.append(node_feats)

    all_node_feats = np.asarray(all_node_feats)
    node_tensor = torch.tensor(all_node_feats, dtype=torch.float)
    # No-op for regular molecules; keeps the matrix 2-D when there are no atoms.
    return node_tensor.reshape(-1, self._NODE_FEAT_DIM)

  def get_edge_features(self, mol):
    """
    This will return a matrix / 2d array of the shape
    [Number of edges, Edge Feature size]
    """
    all_edge_feats = []

    for bond in mol.GetBonds():
      edge_feats = [
          # Feature 1: Bond type (as double)
          bond.GetBondTypeAsDouble(),
          # Feature 2: Rings
          bond.IsInRing(),
          # Feature 3: Conjugation
          bond.GetIsConjugated(),
          # Feature 4: E/Z configuration
          bond.GetStereo(),
      ]
      # Append edge features to matrix (twice, per direction)
      all_edge_feats += [edge_feats, edge_feats]

    all_edge_feats = np.asarray(all_edge_feats)
    edge_tensor = torch.tensor(all_edge_feats, dtype=torch.float)
    # No-op for regular molecules; keeps the matrix 2-D when there are no bonds.
    return edge_tensor.reshape(-1, self._EDGE_FEAT_DIM)

  def get_adjacency_info(self, mol):
    """
    We could also use rdmolops.GetAdjacencyMatrix(mol)
    but we want to be sure that the order of the indices
    matches the order of the edge features
    """
    edge_indices = []
    for bond in mol.GetBonds():
      i = bond.GetBeginAtomIdx()
      j = bond.GetEndAtomIdx()
      # Both directions, in the same order used by get_edge_features.
      edge_indices += [[i, j], [j, i]]

    edge_indices = torch.tensor(edge_indices)
    # Transpose the list of pairs into the [2, num_edges] layout.
    edge_indices = edge_indices.t().to(torch.long).view(2, -1)
    return edge_indices

  def get_labels(self, label):
    """Wrap a single label value into a 1-element int64 tensor."""
    label = np.asarray([label])
    return torch.tensor(label, dtype=torch.int64)

  def len(self):
    """
    Number of graphs, i.e. number of rows in the raw CSV.
    Works even when `process` was skipped in this session.
    """
    return self._raw_frame().shape[0]

  def get(self, idx):
    """Load and return the processed graph stored for position idx."""
    return torch.load(os.path.join(self.processed_dir,
                                   self._processed_name(idx)))