In [1]:
'''
This is a Python class definition for a continuous feature. It has an `__init__` method that initializes the feature with a name, 
a `__repr__` method that returns a string representation of the feature, an `__eq__` method that checks if two features are equal based 
on their names, and a `__hash__` method that returns a hash value for the feature based on its name. 
This class can be used to represent continuous features in a machine learning model or data analysis project.
'''
class ContinuousFeature:
  """A named, real-valued feature.

  Equality and hashing are both defined in terms of `name` only, so two
  instances created with the same name are interchangeable as dictionary
  keys.
  """

  def __init__(self, name):
    """Store the feature's identifier.

    Args:
      name: Hashable identifier of the feature (it is passed to `hash`).
    """
    self.name = name

  def __repr__(self):
    """Return a short, human-readable tag such as `<ContinuousFeature: x>`."""
    return f'<ContinuousFeature: {self.name}>'

  def __eq__(self, other):
    """Compare by name with other `ContinuousFeature` instances.

    Returns `NotImplemented` for any other type so that Python falls back to
    its default comparison instead of raising `AttributeError` (objects
    without a `name`) or reporting a false match (unrelated objects that
    merely share a `name` attribute).
    """
    if not isinstance(other, ContinuousFeature):
      return NotImplemented
    return self.name == other.name

  def __hash__(self):
    """Hash on `name`, consistent with `__eq__`."""
    return hash(self.name)

In [15]:
'''
This is a Python class definition for a categorical feature. It has an `__init__` method that takes in a name, a list of values, and a 
boolean flag indicating whether to add a null value. If the flag is True, a null value is added to the list of values. The class also has 
attributes for the name, whether it has a null value, the list of values, a mapping of values to indices, and a mapping of indices to values. 
If a null value is added, the class also has an attribute for the index of the null value.
'''
class CategoricalFeature:
  """A named feature that takes one value out of a fixed, ordered vocabulary.

  Every value is assigned the integer index of its position in `values`.
  When `add_null_value` is true, `None` is prepended to the vocabulary so
  that it receives index 0.

  Attributes:
    name: Identifier of the feature.
    has_null_value: Whether `None` was prepended to the vocabulary.
    values: Tuple of all values, in index order.
    value_to_idx_mapping: Dict mapping value -> index.
    inv_value_to_idx_mapping: Dict mapping index -> value.
    null_value, null_value_idx: Only set when `has_null_value` is true.
  """

  def __init__(self, name, values, add_null_value=True):
    """Build the vocabulary and both lookup tables.

    Args:
      name: Identifier of the feature.
      values: Iterable of hashable values; it may be a one-shot iterator.
      add_null_value: If true, prepend `None` as an extra "null" value.
    """
    self.name = name
    self.has_null_value = add_null_value
    # Materialise once up front: `values` may be a generator, and it is needed
    # both for `self.values` and for building the index mapping.
    values = tuple(values)
    if self.has_null_value:
      self.null_value = None
      values = (None,) + values
    self.values = values
    self.value_to_idx_mapping = {v: i for i, v in enumerate(values)}
    self.inv_value_to_idx_mapping = {i: v for v, i in
                                     self.value_to_idx_mapping.items()}

    if self.has_null_value:
      self.null_value_idx = self.value_to_idx_mapping[self.null_value]

  def get_null_idx(self):
    """Return the index of the null value.

    Raises:
      RuntimeError: If the feature was created without a null value.
    """
    if self.has_null_value:
      return self.null_value_idx
    else:
      raise RuntimeError(f"Categorical variable {self.name} has no null value")

  def value_to_idx(self, value):
    """Return the index of `value`.

    Raises:
      KeyError: If `value` is not part of the vocabulary.
    """
    return self.value_to_idx_mapping[value]

  def idx_to_value(self, idx):
    """Return the value stored at index `idx`.

    Raises:
      KeyError: If `idx` is not a known index.
    """
    return self.inv_value_to_idx_mapping[idx]

  def __len__(self):
    """Return the vocabulary size, counting the null value if present."""
    return len(self.values)

  def __repr__(self):
    """Return a short, human-readable tag such as `<CategoricalFeature: x>`."""
    return f'<CategoricalFeature: {self.name}>'

  def __eq__(self, other):
    """Compare by name and vocabulary with other `CategoricalFeature`s.

    Returns `NotImplemented` for any other type so that comparing with an
    unrelated object yields "not equal" instead of raising `AttributeError`.
    """
    if not isinstance(other, CategoricalFeature):
      return NotImplemented
    return self.name == other.name and self.values == other.values

  def __hash__(self):
    """Hash on `(name, values)`, consistent with `__eq__`."""
    return hash((self.name, self.values))

In [4]:
# Atom types
# Element symbols that make up the atom-symbol vocabulary. The position of a
# symbol in this list determines its categorical index, so the order must not
# be changed.
# NOTE(review): Nh, Mc, Ts and Og do not appear in this list.
ATOM_SYMBOLS = [
    # H through Ba, in atomic-number order
    'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
    'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca',
    'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
    'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr',
    'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn',
    'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba',
    # Hf through Rn
    'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl',
    'Pb', 'Bi', 'Po', 'At', 'Rn',
    # Fr, Ra, then Rf through Cn, Fl and Lv
    'Fr', 'Ra', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg',
    'Cn', 'Fl', 'Lv',
    # La through Lu
    'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy',
    'Ho', 'Er', 'Tm', 'Yb', 'Lu',
    # Ac through Lr
    'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf',
    'Es', 'Fm', 'Md', 'No', 'Lr',
]

In [8]:
# Element symbol of the atom. CategoricalFeature is created with its default
# add_null_value=True, so None is prepended and real symbols start at index 1.
SYMBOLS_FEATURE = CategoricalFeature('atom_symbol', ATOM_SYMBOLS)

# Aromaticity flag (True / False, plus the default null value)
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalFeature('is_aromatic', AROMATIC_VALUES)

# Explicit valence. The misspelt name 'explicit_valance' is the feature's
# runtime identity (used for equality and hashing), so it is left as is.
EXPLICIT_VALANCE_FEATURE = ContinuousFeature('explicit_valance')

# Implicit valence (same note on the spelling of the name)
IMPLICIT_VALANCE_FEATURE = ContinuousFeature('implicit_valance')

# Combine all four into one list of features: two categorical, two continuous
ATOM_FEATURES = [SYMBOLS_FEATURE,
                 AROMATIC_FEATURE,
                 EXPLICIT_VALANCE_FEATURE,
                 IMPLICIT_VALANCE_FEATURE]

In [10]:
# Bond types
# Vocabulary for the bond-type feature. Values are matched against
# str(rd_bond.GetBondType()); list position determines the categorical index,
# so the order must not be changed.
BOND_TYPES = [
    'UNSPECIFIED',
    # Integer bond orders
    'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE', 'QUINTUPLE', 'HEXTUPLE',
    # Half-integer bond orders
    'ONEANDAHALF', 'TWOANDAHALF', 'THREEANDAHALF', 'FOURANDAHALF',
    'FIVEANDAHALF',
    # Everything else
    'AROMATIC', 'IONIC', 'HYDROGEN', 'THREECENTER',
    'DATIVEONE', 'DATIVE', 'DATIVEL', 'DATIVER',
    'OTHER', 'ZERO',
]

In [13]:
# Bond type; like every CategoricalFeature built here it uses the default
# add_null_value=True, so index 0 is the None null value.
TYPE_FEATURE = CategoricalFeature('bond_type', BOND_TYPES)

# Bond directions, matched against str(rd_bond.GetBondDir())
# NOTE(review): RDKit's BondDir enum appears to also define UNKNOWN; a bond
# with a direction missing from this list would raise KeyError in
# value_to_idx - confirm and extend the list if needed.
BOND_DIRECTIONS = ['NONE', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT',
                   'ENDUPRIGHT', 'EITHERDOUBLE' ]
DIRECTION_FEATURE = CategoricalFeature('bond_direction', BOND_DIRECTIONS)

# Bond stereochemistry, matched against str(rd_bond.GetStereo())
BOND_STEREO = ['STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE', 
               'STEREOCIS', 'STEREOTRANS']
STEREO_FEATURE = CategoricalFeature('bond_stereo', BOND_STEREO)

# Aromaticity. This rebinds AROMATIC_VALUES / AROMATIC_FEATURE with the same
# name and values used for the atom feature; CategoricalFeature equality and
# hashing depend only on name and values, so the new object is
# interchangeable with the earlier one as a dictionary key.
AROMATIC_VALUES = [True, False]
AROMATIC_FEATURE = CategoricalFeature('is_aromatic', AROMATIC_VALUES)

# Combine all four into one list of features (all of them categorical)
BOND_FEATURES = [TYPE_FEATURE,
                 DIRECTION_FEATURE,
                 AROMATIC_FEATURE,
                 STEREO_FEATURE]

In [16]:
# Atom features
'''
This is a Python function that takes an RDKit atom object as input and returns a dictionary of features for that atom.
The features include the atom symbol, whether the atom is aromatic, the explicit and implicit valence of the atom.
'''
def get_atom_features(rd_atom):
  """Read the feature values of a single atom.

  Args:
    rd_atom: Object exposing `GetSymbol`, `GetIsAromatic`,
      `GetImplicitValence` and `GetExplicitValence` (an RDKit atom).

  Returns:
    Dict keyed by the module-level feature objects: the symbol and the
    aromaticity flag as returned by the atom, and both valences as floats.
  """
  symbol = rd_atom.GetSymbol()
  aromatic_flag = rd_atom.GetIsAromatic()
  # The implicit valence is queried before the explicit one.
  n_implicit = float(rd_atom.GetImplicitValence())
  n_explicit = float(rd_atom.GetExplicitValence())

  features = {}
  features[SYMBOLS_FEATURE] = symbol
  features[AROMATIC_FEATURE] = aromatic_flag
  features[EXPLICIT_VALANCE_FEATURE] = n_explicit
  features[IMPLICIT_VALANCE_FEATURE] = n_implicit
  return features

In [None]:
# Bond features
'''
This is a Python function called `get_bond_features` that takes a single argument `rd_bond`. The function extracts various
features of a chemical bond represented by `rd_bond` using functions from the RDKit library and returns a dictionary containing these features.
The features extracted include the bond type, direction, aromaticity, and stereochemistry. The keys of the dictionary correspond to the names of these features.
'''
def get_bond_features(rd_bond):
  """Read the feature values of a single bond.

  Args:
    rd_bond: Object exposing `GetBondType`, `GetStereo`, `GetBondDir` and
      `GetIsAromatic` (an RDKit bond).

  Returns:
    Dict keyed by the module-level feature objects. Type, stereo and
    direction are stored as the `str()` of what the bond returns; the
    aromaticity flag is stored unchanged.
  """
  type_name = str(rd_bond.GetBondType())
  stereo_name = str(rd_bond.GetStereo())
  direction_name = str(rd_bond.GetBondDir())
  aromatic_flag = rd_bond.GetIsAromatic()

  features = {}
  features[TYPE_FEATURE] = type_name
  features[DIRECTION_FEATURE] = direction_name
  features[AROMATIC_FEATURE] = aromatic_flag
  features[STEREO_FEATURE] = stereo_name
  return features

In [None]:
# Create dictionaries of the atoms and bonds in a molecule
'''
This code defines a function called `rdmol_to_graph` that takes a molecule object (`rd_mol`) as input and returns two dictionaries: `atoms` and `bonds`.
The `atoms` dictionary contains the features of each atom in the molecule, where the keys are the atom indices and the values are the features obtained by
calling the `get_atom_features` function. The `bonds` dictionary contains the features of each bond in the molecule, where the keys are frozensets of the
indices of the atoms that the bond connects, and the values are the features obtained by calling the `get_bond_features` function.
'''
def rdmol_to_graph(rd_mol):
  """Turn a molecule into a pair of feature dictionaries.

  Args:
    rd_mol: Object exposing `GetAtoms()` and `GetBonds()` (an RDKit molecule).

  Returns:
    Tuple `(atoms, bonds)`. `atoms` maps each atom's `GetIdx()` to its
    `get_atom_features` dict. `bonds` maps the frozenset of a bond's begin
    and end atom indices to its `get_bond_features` dict.
  """
  atoms = {}
  for rd_atom in rd_mol.GetAtoms():
    atoms[rd_atom.GetIdx()] = get_atom_features(rd_atom)

  bonds = {}
  for rd_bond in rd_mol.GetBonds():
    # The key is unordered, so the bond has no begin/end direction.
    endpoints = frozenset((rd_bond.GetBeginAtomIdx(),
                           rd_bond.GetEndAtomIdx()))
    bonds[endpoints] = get_bond_features(rd_bond)

  return atoms, bonds

In [None]:
'''
This is a Python function that takes a SMILES string as input and returns the molecule's graph representation.
The function first converts the SMILES string to a molecule object using the `MolFromSmiles` function from the RDKit library.
Then, it passes the molecule object to the `rdmol_to_graph` function and returns the result, which is a pair of dictionaries `(atoms, bonds)`.
'''
def smiles_to_graph(smiles):
  """Parse a SMILES string and convert the molecule to a graph.

  Args:
    smiles: SMILES string describing the molecule.

  Returns:
    The `(atoms, bonds)` pair of dictionaries produced by `rdmol_to_graph`.

  Raises:
    ValueError: If the SMILES string could not be parsed into a molecule.
  """
  rd_mol = MolFromSmiles(smiles)
  # RDKit signals a parse failure by returning None rather than raising;
  # fail here with a clear message instead of an AttributeError further down.
  if rd_mol is None:
    raise ValueError(f"Could not parse SMILES string: {smiles!r}")
  return rdmol_to_graph(rd_mol)

In [None]:
class GraphDataset(Dataset):
  """Dataset of graphs that builds dense feature tensors on item access.

  Each graph is a `(nodes, edges)` pair. `nodes` maps a node index to a dict
  of `{feature: value}`; `edges` maps an edge key (an iterable of two node
  indices) to a dict of `{feature: value}`. Node and edge indices are used
  directly as tensor indices, so they must be valid indices for a tensor
  with `len(nodes)` rows. An edge whose key is a set (for example a
  `frozenset`) is treated as undirected and written in both directions.
  """

  def __init__(self, *, graphs, labels, node_variables, edge_variables,
               metadata=None):
    """Create a new graph dataset.

    Args:
      graphs: Sequence of `(nodes, edges)` pairs.
      labels: Sequence of labels, one per graph.
      node_variables: Features to extract for every node.
      edge_variables: Features to extract for every edge.
      metadata: Optional sequence with one extra record per graph.

    Raises:
      AssertionError: If `labels` (or `metadata`, when given) does not have
        the same length as `graphs`.
    """
    self.graphs = graphs
    self.labels = labels
    assert len(self.graphs) == len(self.labels), \
      "The graphs and labels lists must be the same length"
    self.metadata = metadata
    if self.metadata is not None:
      assert len(self.metadata) == len(self.graphs),\
        "The metadata list needs to be as long as the graphs"
    self.node_variables = node_variables
    self.edge_variables = edge_variables
    # Split the variables by kind once; each kind gets its own tensor.
    self.categorical_node_variables = [var for var in self.node_variables
                                       if isinstance(var, CategoricalFeature)]
    self.continuous_node_variables = [var for var in self.node_variables
                                      if isinstance(var, ContinuousFeature)]
    self.categorical_edge_variables = [var for var in self.edge_variables
                                       if isinstance(var, CategoricalFeature)]
    self.continuous_edge_variables = [var for var in self.edge_variables
                                      if isinstance(var, ContinuousFeature)]

  def __len__(self):
    """Return the number of graphs in the dataset."""
    return len(self.graphs)

  @staticmethod
  def _edge_cells(edge):
    """Return the `(row, column)` cells that `edge` occupies.

    An ordered edge yields the single cell `(u, v)`. An unordered edge (any
    set type) yields both `(u, v)` and `(v, u)`.
    """
    endpoints = tuple(edge)
    # `frozenset` is checked explicitly so that the edges produced by
    # `rdmol_to_graph` are recognised as undirected whichever `Set` name the
    # notebook has in scope.
    unordered = isinstance(edge, (Set, frozenset))
    if unordered and len(endpoints) == 1:
      # frozenset((a, a)) collapses to one element: it is a self-loop.
      endpoints = endpoints * 2
    u, v = endpoints
    if unordered:
      return ((u, v), (v, u))
    return ((u, v),)

  def make_continuous_node_features(self, nodes):
    """Build the `(n_nodes, n_features)` tensor of continuous node features.

    Args:
      nodes: Dict mapping node index to its `{feature: value}` dict.

    Returns:
      A `float_type` tensor with one row per node and one column per
      continuous node variable, or None if there are no such variables.
    """
    if len(self.continuous_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.continuous_node_variables)
    continuous_node_features = torch.zeros((n_nodes, n_features),
                                           dtype=float_type)
    for node_idx, features in nodes.items():
      node_features = torch.tensor([features[continuous_feature]
                                    for continuous_feature
                                    in self.continuous_node_variables],
                                   dtype=float_type)
      continuous_node_features[node_idx] = node_features
    return continuous_node_features

  def make_categorical_node_features(self, nodes):
    """Build the `(n_nodes, n_features)` tensor of categorical node features.

    Each entry is the vocabulary index of the node's value for that
    variable, obtained through the variable's `value_to_idx`.

    Args:
      nodes: Dict mapping node index to its `{feature: value}` dict.

    Returns:
      A `categorical_type` tensor with one row per node and one column per
      categorical node variable, or None if there are no such variables.
    """
    if len(self.categorical_node_variables) == 0:
      return None
    n_nodes = len(nodes)
    n_features = len(self.categorical_node_variables)
    categorical_node_features = torch.zeros((n_nodes, n_features),
                                            dtype=categorical_type)

    for node_idx, features in nodes.items():
      for i, categorical_variable in enumerate(self.categorical_node_variables):
        value = features[categorical_variable]
        value_index = categorical_variable.value_to_idx(value)
        categorical_node_features[node_idx, i] = value_index

    return categorical_node_features

  def make_continuous_edge_features(self, n_nodes, edges):
    """Build the `(n_nodes, n_nodes, n_features)` continuous edge tensor.

    Cells without an edge stay zero. Undirected edges are written to both
    `[u, v]` and `[v, u]`.

    Args:
      n_nodes: Number of nodes in the graph.
      edges: Dict mapping edge key to its `{feature: value}` dict.

    Returns:
      A `float_type` tensor, or None if there are no continuous edge
      variables.
    """
    if len(self.continuous_edge_variables) == 0:
      return None
    n_features = len(self.continuous_edge_variables)
    continuous_edge_features = torch.zeros((n_nodes, n_nodes, n_features),
                                           dtype=float_type)
    for edge, features in edges.items():
      edge_features = torch.tensor([features[continuous_feature]
                                    for continuous_feature in
                                    self.continuous_edge_variables],
                                   dtype=float_type)
      for u, v in self._edge_cells(edge):
        continuous_edge_features[u, v] = edge_features

    return continuous_edge_features

  def make_categorical_edge_features(self, n_nodes, edges):
    """Build the `(n_nodes, n_nodes, n_features)` categorical edge tensor.

    Each entry is the vocabulary index of the edge's value for that
    variable. Cells without an edge stay zero. Undirected edges are written
    to both `[u, v]` and `[v, u]`.

    Args:
      n_nodes: Number of nodes in the graph.
      edges: Dict mapping edge key to its `{feature: value}` dict.

    Returns:
      A `categorical_type` tensor, or None if there are no categorical edge
      variables.
    """
    if len(self.categorical_edge_variables) == 0:
      return None
    n_features = len(self.categorical_edge_variables)
    categorical_edge_features = torch.zeros((n_nodes, n_nodes, n_features),
                                            dtype=categorical_type)

    for edge, features in edges.items():
      cells = self._edge_cells(edge)
      for i, categorical_variable in enumerate(self.categorical_edge_variables):
        value = features[categorical_variable]
        value_index = categorical_variable.value_to_idx(value)
        for u, v in cells:
          categorical_edge_features[u, v, i] = value_index

    return categorical_edge_features

  def __getitem__(self, index):
    """Return graph `index` as a dict of tensors plus its label.

    The node and edge variables select which features are extracted. This
    is often done as a pre-processing step; it is done on access here for
    clarity.

    Returns:
      Dict with the keys 'nodes' (sorted node indices), 'adjacency_matrix',
      'categorical_node_features', 'continuous_node_features',
      'categorical_edge_features', 'continuous_edge_features' and 'label'.
      A feature entry is None when there are no variables of that kind.
      'metadata' is added when the dataset was created with metadata.
    """
    nodes, edges = self.graphs[index]
    n_nodes = len(nodes)
    continuous_node_features = self.make_continuous_node_features(nodes)
    categorical_node_features = self.make_categorical_node_features(nodes)
    continuous_edge_features = self.make_continuous_edge_features(n_nodes,
                                                                  edges)
    categorical_edge_features = self.make_categorical_edge_features(n_nodes,
                                                                    edges)

    label = self.labels[index]

    nodes_idx = sorted(nodes.keys())

    adjacency_matrix = torch.zeros((n_nodes, n_nodes), dtype=float_type)
    for edge in edges:
      # Unordered edges fill both [u, v] and [v, u] (undirected graph).
      for u, v in self._edge_cells(edge):
        adjacency_matrix[u, v] = 1

    data_record = {'nodes': nodes_idx,
                   'adjacency_matrix': adjacency_matrix,
                   'categorical_node_features': categorical_node_features,
                   'continuous_node_features': continuous_node_features,
                   'categorical_edge_features': categorical_edge_features,
                   'continuous_edge_features': continuous_edge_features,
                   'label': label}

    # A dict (rather than a tuple) lets extra keys such as metadata be added
    # without breaking downstream code that only reads the expected keys, at
    # the cost of some lookup overhead. A more robust implementation might
    # use a dedicated class for dataset entries.
    if self.metadata is not None:
      data_record['metadata'] = self.metadata[index]

    return data_record

  def get_node_variables(self):
    """Return the node variables grouped as 'continuous' / 'categorical'."""
    return {'continuous': self.continuous_node_variables,
            'categorical': self.categorical_node_variables}

  def get_edge_variables(self):
    """Return the edge variables grouped as 'continuous' / 'categorical'."""
    return {'continuous': self.continuous_edge_variables,
            'categorical': self.categorical_edge_variables}

In [17]:
'''
This is a Python function that takes a list of dictionaries containing SMILES strings and labels as input, and returns a new GraphDataset object.
The function converts each SMILES string to a graph using the `smiles_to_graph` function, and appends the resulting graph, label, and metadata to separate lists.
These lists are then used to create a new GraphDataset object with the specified node and edge variables.
'''
def make_molecular_graph_dataset(smiles_records, atom_features=ATOM_FEATURES,
                                 bond_features=BOND_FEATURES):
  """Build a GraphDataset from records holding SMILES strings and labels.

  Args:
    smiles_records: Iterable of dicts, each with at least the keys 'smiles'
      and 'label'.
    atom_features: Node variables handed to the dataset.
    bond_features: Edge variables handed to the dataset.

  Returns:
    A GraphDataset with one graph per record. The complete record, including
    its 'smiles' and 'label' entries, is stored as that graph's metadata.
  """
  graphs, labels, metadata = [], [], []
  for record in smiles_records:
    # Read both required keys before doing any conversion work.
    smiles_string, target = record['smiles'], record['label']
    graphs.append(smiles_to_graph(smiles_string))
    labels.append(target)
    metadata.append(record)

  dataset_kwargs = dict(graphs=graphs,
                        labels=labels,
                        node_variables=atom_features,
                        edge_variables=bond_features,
                        metadata=metadata)
  return GraphDataset(**dataset_kwargs)