In [1]:
!pip install torch_geometric
!pip install matplotlib
!pip install rdkit
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-2.0.1+cu117.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-2.0.1+cu117.html


# Helper function for visualization.
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt


def visualize_graph(G, color):
    """Draw graph ``G`` on a 7x7 tick-free figure using a seeded spring layout.

    Args:
        G: Graph object passed to ``nx.spring_layout`` and ``nx.draw_networkx``.
        color: Node colour values, forwarded as ``node_color`` and mapped
            through the "Set2" colormap.
    """
    plt.figure(figsize=(7, 7))
    # Hide both axes' tick marks; the layout coordinates carry no meaning.
    for clear_ticks in (plt.xticks, plt.yticks):
        clear_ticks([])
    # Fixed seed so repeated calls place the nodes identically.
    layout = nx.spring_layout(G, seed=42)
    nx.draw_networkx(G, pos=layout, with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()


def visualize_embedding(h, color, epoch=None, loss=None):
    """Scatter-plot the first two columns of the embedding tensor ``h``.

    Args:
        h: Tensor of embeddings; it is detached, moved to CPU and converted to
            NumPy, and columns 0 and 1 are used as x / y coordinates.
        color: Point colours, mapped through the "Set2" colormap.
        epoch: Optional epoch number shown in the x-axis label.
        loss: Optional loss object exposing ``.item()``; shown in the label.
            The label is only drawn when both ``epoch`` and ``loss`` are given.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    points = h.detach().cpu().numpy()
    xs, ys = points[:, 0], points[:, 1]
    plt.scatter(xs, ys, s=140, c=color, cmap="Set2")
    caption_requested = not (epoch is None or loss is None)
    if caption_requested:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()

import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import BatchNorm, PNAConv, global_add_pool
from torch_geometric.utils import degree

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m61.4/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Collecting rdkit
  Downloading rdkit-2024.3.6-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.3.6-cp310-cp310-manylinux_2_28_x86_64.whl (32.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.8/32.8 MB[0m 



In [2]:

import torch

def edge_index_to_adj(edge_index, num_nodes, edge_attr=None):
    """
    Build a dense, symmetric adjacency matrix from an edge list.

    Args:
        edge_index (torch.Tensor): Edge indices of shape [2, num_edges].
        num_nodes (int): Number of nodes; the result is num_nodes x num_nodes.
        edge_attr (torch.Tensor, optional): One value per edge (e.g. a bond-type
            code). When omitted, every listed edge gets the value 1.

    Returns:
        torch.Tensor: Float matrix of shape [num_nodes, num_nodes]; entry (i, j)
                      holds the value of the edge between i and j, 0 elsewhere.
    """
    # Per-edge values, always as float so they match the matrix dtype.
    if edge_attr is None:
        weights = torch.ones(edge_index.size(1), dtype=torch.float)
    else:
        weights = edge_attr.to(torch.float)

    sources, targets = edge_index[0], edge_index[1]
    adjacency = torch.zeros(num_nodes, num_nodes, dtype=torch.float)

    # Write each edge in its listed direction first, then mirrored, so the
    # matrix is symmetric (undirected graph).
    for rows, cols in ((sources, targets), (targets, sources)):
        adjacency[rows, cols] = weights

    return adjacency
import pandas as pd
from rdkit import Chem

def load_qm9_smiles(csv_file):
    """Return the ``smiles`` column of a CSV file as a Python list.

    Args:
        csv_file: Path or file-like object accepted by ``pd.read_csv``; the
            table must contain a column named ``smiles``.

    Returns:
        list: The values of the ``smiles`` column, in file order.
    """
    return pd.read_csv(csv_file)['smiles'].tolist()

# Example usage
# The path is relative, so it is resolved against the kernel's working directory.
csv_file = "qm9.csv"  # Replace with the path to your QM9 CSV file
# Module-level list of SMILES strings; it is consumed further down when the
# hydrogen-stripped copy is built.
qm9_smiles = load_qm9_smiles(csv_file)

print("Number of SMILES in QM9 dataset:", len(qm9_smiles))
# Index 1 is the second entry of the list, not the first.
print("Example SMILES:", qm9_smiles[1])
def remove_hydrogen_from_smiles(smiles_list):
    """Re-serialise each SMILES string after passing it through ``Chem.RemoveHs``.

    Args:
        smiles_list: Iterable of SMILES strings.

    Returns:
        list: One SMILES string per input that RDKit could parse, in input
        order. Unparseable inputs are reported with ``print`` and skipped, so
        the result can be shorter than the input.
    """
    stripped = []
    for raw in smiles_list:
        parsed = Chem.MolFromSmiles(raw)
        if parsed is not None:
            stripped.append(Chem.MolToSmiles(Chem.RemoveHs(parsed)))
        else:
            print("Invalid SMILES:", raw)
    return stripped

# Example usage
# Assuming qm9_smiles is a list containing SMILES strings from the QM9 dataset
# The helper skips entries RDKit cannot parse, so modified_smiles may be shorter
# than qm9_smiles and positions in the two lists need not line up.
modified_smiles = remove_hydrogen_from_smiles(qm9_smiles)

import torch
from rdkit import Chem

def smiles_to_graph(smiles):
    """Convert a SMILES string into node features, edge indices and bond-type codes.

    Args:
        smiles (str): SMILES string to parse with RDKit.

    Returns:
        tuple: ``(node_features, edge_index, edge_types)`` where

            * ``node_features`` is a float tensor of shape [num_atoms, 1]
              holding each atom's atomic number,
            * ``edge_index`` is a long tensor of shape [2, 2 * num_bonds]; every
              bond appears once per direction,
            * ``edge_types`` is a long tensor of shape [2 * num_bonds] with the
              code 1 (single), 2 (double), 3 (triple), 4 (aromatic) or 0 (any
              other bond type), aligned with the columns of ``edge_index``.

        ``(None, None, None)`` is returned when RDKit cannot parse ``smiles``.
    """
    # Parse the SMILES string
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None, None, None

    # Integer code per RDKit bond type; anything not listed maps to 0.
    bond_codes = {
        Chem.rdchem.BondType.SINGLE: 1,
        Chem.rdchem.BondType.DOUBLE: 2,
        Chem.rdchem.BondType.TRIPLE: 3,
        Chem.rdchem.BondType.AROMATIC: 4,
    }

    # Node features (atomic numbers)
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    # Edge list and edge labels; both directions are stored for every bond.
    edge_pairs = []
    edge_types = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        code = bond_codes.get(bond.GetBondType(), 0)
        edge_pairs.extend(([begin, end], [end, begin]))
        edge_types.extend((code, code))

    # reshape(-1, 2) keeps the result two-dimensional even when the molecule has
    # no bonds: an empty list would otherwise become a 1-D tensor that .t()
    # returns unchanged, giving shape [0] instead of [2, 0].
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_types = torch.tensor(edge_types, dtype=torch.long)  # Shape [num_edges]

    # Atomic numbers as a column vector, shape [num_nodes, 1]
    node_features = torch.tensor(atomic_numbers, dtype=torch.float).unsqueeze(1)

    return node_features, edge_index, edge_types


import torch
from torch_geometric.data import Data
import torch.nn.functional as F

# Module-level list of Data graphs; it is filled by the loop below and later
# handed to the training loop.
filtered_dataset = []

# Define encoding mappings for atomic numbers to one-hot encodings
# The vectors are 5 wide; position 0 is not assigned to any element listed here.
encoding_mappings = {
    7: [0, 0, 1, 0, 0],  # Nitrogen
    8: [0, 0, 0, 1, 0],  # Oxygen
    6: [0, 1, 0, 0, 0],  # Carbon
    9: [0, 0, 0, 0, 1]   # Fluorine
}

# Iterate over modified SMILES
for smile in modified_smiles:
    try:
        # Convert SMILES to graph representation with node features, edge index, and edge types
        # smiles_to_graph returns (None, None, None) for an unparseable string; in
        # that case the .shape access below raises and the except branch reports it.
        node_features, edge_index1, edge_types = smiles_to_graph(smile)

        # Check if the graph has more than one node
        # Single-atom molecules are dropped from the dataset.
        num_nodes = node_features.shape[0]
        if num_nodes > 1:
            # Convert node features to one-hot encoding
            # num.item() is a Python float (node_features is a float tensor); the
            # lookup still hits the int keys because equal numbers hash equally.
            # An atomic number missing from encoding_mappings raises KeyError,
            # which the except branch below reports.
            one_hot_encoded = torch.tensor([encoding_mappings[num.item()] for num in node_features.squeeze()], dtype=torch.float32)


            # Create Data object with x, edge_index, and edge_attr for edge types
            graph = Data(x=one_hot_encoded, edge_index=edge_index1, edge_attr=edge_types, num_nodes=num_nodes)
            filtered_dataset.append(graph)
    except Exception as e:
        # Best-effort: any failure for one molecule is printed and that molecule
        # is skipped; the loop continues with the next SMILES.
        print(f"Error processing SMILES: {smile}. {e}")



Number of SMILES in QM9 dataset: 133885
Example SMILES: N


Defining GNN-related code


In [3]:
from torch import nn
import torch
import math

def unsorted_segment_sum(data, segment_ids, num_segments, normalization_factor, aggregation_method: str):
    """Scatter-add the rows of ``data`` into ``num_segments`` buckets.

    Args:
        data: 2-D tensor [num_items, feat]; row ``i`` is added to bucket
            ``segment_ids[i]``.
        segment_ids: 1-D index tensor with one bucket id per row of ``data``.
        num_segments: Number of output rows.
        normalization_factor: Divisor applied to the sums when
            ``aggregation_method == 'sum'``.
        aggregation_method: ``'sum'`` divides the bucket sums by
            ``normalization_factor``; ``'mean'`` divides them by the number of
            rows that landed in each bucket (empty buckets divide by 1). Any
            other value returns the raw sums.

    Returns:
        Tensor of shape [num_segments, feat].
    """
    feat_dim = data.size(1)
    # Broadcast the per-row bucket id across the feature axis for scatter_add_.
    index = segment_ids.unsqueeze(-1).expand(-1, feat_dim)
    totals = data.new_full((num_segments, feat_dim), 0)
    totals.scatter_add_(0, index, data)

    if aggregation_method == 'sum':
        totals = totals / normalization_factor

    if aggregation_method == 'mean':
        counts = data.new_zeros(totals.shape)
        counts.scatter_add_(0, index, data.new_ones(data.shape))
        # Avoid 0/0 for buckets that received no rows.
        counts[counts == 0] = 1
        totals = totals / counts
    return totals

class GCL(nn.Module):
    """Graph convolution layer: edge MLP, optional sigmoid gate, aggregation, residual node MLP.

    Args:
        input_nf: Width of the incoming node features.
        output_nf: Width produced by the node MLP (added to the input as a
            residual in ``node_model``).
        hidden_nf: Width of the edge messages and of the hidden MLP layers.
        normalization_factor: Forwarded to ``unsorted_segment_sum``.
        aggregation_method: Forwarded to ``unsorted_segment_sum``.
        edges_in_d: Width of the extra per-edge input features.
        nodes_att_dim: Width of the extra per-node attributes.
        act_fn: Activation module placed between the linear layers.
        attention: When true, messages are scaled by a learned sigmoid gate.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super().__init__()
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        # Edge MLP input: both endpoint features plus the optional edge features.
        edge_in_features = input_nf * 2 + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # Node MLP input: aggregated messages, node features, optional attributes.
        node_in_features = hidden_nf + input_nf + nodes_att_dim
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        """Return ``(message, mij)``: the (gated, masked) message and the raw edge-MLP output."""
        pieces = [source, target]
        if edge_attr is not None:
            pieces.append(edge_attr)
        mij = self.edge_mlp(torch.cat(pieces, dim=1))

        # With attention enabled each message is scaled by a scalar in (0, 1).
        out = mij * self.att_mlp(mij) if self.attention else mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Aggregate messages onto the first row of ``edge_index`` and apply the residual node MLP."""
        row, _ = edge_index
        pooled = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                      normalization_factor=self.normalization_factor,
                                      aggregation_method=self.aggregation_method)
        parts = [x, pooled]
        if node_attr is not None:
            parts.append(node_attr)
        agg = torch.cat(parts, dim=1)
        return x + self.node_mlp(agg), agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step; returns ``(updated_h, mij)``."""
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij

class GNN(nn.Module):
    """Stack of ``GCL`` layers between an input and an output linear embedding.

    Args:
        in_node_nf: Width of the input node features.
        in_edge_nf: Width of the edge features given to every ``GCL`` layer.
        hidden_nf: Hidden width used by all layers.
        out_node_nf: Output width; ``None`` means "same as ``in_node_nf``".
        aggregation_method: Forwarded to each ``GCL``.
        device: Device the module is moved to at the end of construction.
        act_fn: Activation module forwarded to each ``GCL``.
        n_layers: Number of ``GCL`` layers.
        attention: Forwarded to each ``GCL``.
        normalization_factor: Forwarded to each ``GCL``.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, out_node_nf,aggregation_method='sum', device='cpu',
                 act_fn=nn.SiLU(), n_layers=4, attention=False,
                 normalization_factor=100, ):
        super().__init__()
        out_width = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        ### Encoder
        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_width)
        # Layers are registered under the names gcl_0 ... gcl_{n_layers-1}.
        for i in range(n_layers):
            layer = GCL(
                self.hidden_nf, self.hidden_nf, self.hidden_nf,
                normalization_factor=normalization_factor,
                aggregation_method=aggregation_method,
                edges_in_d=in_edge_nf, act_fn=act_fn,
                attention=attention)
            self.add_module("gcl_%d" % i, layer)
        self.to(self.device)

    def forward(self, h, edges, edge_attr=None, node_mask=None, edge_mask=None):
        """Embed ``h``, run every ``GCL`` layer in order, project, then apply ``node_mask``."""
        h = self.embedding(h)
        for i in range(self.n_layers):
            layer = getattr(self, "gcl_%d" % i)
            h, _ = layer(h, edges, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)

        # Important, the bias of the last linear might be non-zero, so the mask
        # has to be applied again after the output projection.
        return h if node_mask is None else h * node_mask
import numpy as np

def fully_connected_graph(num_nodes):
    """Return the edge index of a complete directed graph without self-loops.

    Args:
        num_nodes (int): Number of nodes.

    Returns:
        torch.Tensor: Long tensor of shape [2, num_nodes * (num_nodes - 1)];
        column k is the (source, target) pair of edge k, ordered by source and
        then by target.
    """
    nodes = np.arange(num_nodes)
    # 'ij' indexing makes the first grid vary along rows, so flattening yields
    # the pairs in source-major order.
    src_grid, dst_grid = np.meshgrid(nodes, nodes, indexing='ij')
    sources, targets = src_grid.ravel(), dst_grid.ravel()

    # Drop the diagonal (i, i) pairs.
    off_diagonal = sources != targets
    stacked = np.array([sources[off_diagonal], targets[off_diagonal]])

    return torch.tensor(stacked, dtype=torch.long).contiguous()

# Example usage for a fully connected graph with 4 nodes



Defining EGNN-related code

In [4]:

def coord2diff(x, edge_index, norm_constant=1):
    """Compute per-edge squared distances and normalised coordinate differences.

    Args:
        x: Coordinate tensor indexed by node along dim 0.
        edge_index: Pair ``(row, col)`` of node-index tensors, one entry per edge.
        norm_constant: Added to the edge length in the denominator.

    Returns:
        tuple: ``(radial, coord_diff)`` where ``radial`` has shape
        [num_edges, 1] and holds the squared length of ``x[row] - x[col]``, and
        ``coord_diff`` is that difference divided by ``length + norm_constant``.
    """
    row, col = edge_index
    delta = x[row] - x[col]
    radial = torch.sum(delta ** 2, 1).unsqueeze(1)
    # The small epsilon keeps the square root well-defined at zero distance.
    length = torch.sqrt(radial + 1e-8)
    return radial, delta / (length + norm_constant)
class GCL(nn.Module):
    """Graph convolution layer: edge MLP, optional sigmoid gate, aggregation, residual node MLP.

    This cell defines ``GCL`` again; once it has run, this definition is the
    one bound to the name ``GCL``.

    Args:
        input_nf: Width of the incoming node features.
        output_nf: Width produced by the node MLP (added to the input as a
            residual in ``node_model``).
        hidden_nf: Width of the edge messages and of the hidden MLP layers.
        normalization_factor: Forwarded to ``unsorted_segment_sum``.
        aggregation_method: Forwarded to ``unsorted_segment_sum``.
        edges_in_d: Width of the extra per-edge input features.
        nodes_att_dim: Width of the extra per-node attributes.
        act_fn: Activation module placed between the linear layers.
        attention: When true, messages are scaled by a learned sigmoid gate.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=0, nodes_att_dim=0, act_fn=nn.SiLU(), attention=False):
        super().__init__()
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method
        self.attention = attention

        # Edge MLP input: both endpoint features plus the optional edge features.
        edge_in_features = input_nf * 2 + edges_in_d
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        # Node MLP input: aggregated messages, node features, optional attributes.
        node_in_features = hidden_nf + input_nf + nodes_att_dim
        self.node_mlp = nn.Sequential(
            nn.Linear(node_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, edge_attr, edge_mask):
        """Return ``(message, mij)``: the (gated, masked) message and the raw edge-MLP output."""
        pieces = [source, target]
        if edge_attr is not None:
            pieces.append(edge_attr)
        mij = self.edge_mlp(torch.cat(pieces, dim=1))

        # With attention enabled each message is scaled by a scalar in (0, 1).
        out = mij * self.att_mlp(mij) if self.attention else mij

        if edge_mask is not None:
            out = out * edge_mask
        return out, mij

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Aggregate messages onto the first row of ``edge_index`` and apply the residual node MLP."""
        row, _ = edge_index
        pooled = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0),
                                      normalization_factor=self.normalization_factor,
                                      aggregation_method=self.aggregation_method)
        parts = [x, pooled]
        if node_attr is not None:
            parts.append(node_attr)
        agg = torch.cat(parts, dim=1)
        return x + self.node_mlp(agg), agg

    def forward(self, h, edge_index, edge_attr=None, node_attr=None, node_mask=None, edge_mask=None):
        """Run one message-passing step; returns ``(updated_h, mij)``."""
        row, col = edge_index
        edge_feat, mij = self.edge_model(h[row], h[col], edge_attr, edge_mask)
        h, _ = self.node_model(h, edge_index, edge_feat, node_attr)
        if node_mask is not None:
            h = h * node_mask
        return h, mij


class EquivariantUpdate(nn.Module):
    """Coordinate update: moves each node along its normalised edge directions.

    A scalar per edge is predicted from the two endpoint features and the edge
    attributes, multiplied with ``coord_diff`` and aggregated per node.

    Args:
        hidden_nf: Width of the node features and of the hidden layers.
        normalization_factor: Forwarded to ``unsorted_segment_sum``.
        aggregation_method: Forwarded to ``unsorted_segment_sum``.
        edges_in_d: Width of the per-edge attributes.
        act_fn: Activation module placed between the linear layers.
        tanh: When true, the per-edge scalar is squashed with ``tanh`` and
            multiplied by ``coords_range``.
        coords_range: Scale used together with ``tanh``.
    """

    def __init__(self, hidden_nf, normalization_factor, aggregation_method,
                 edges_in_d=1, act_fn=nn.SiLU(), tanh=False, coords_range=10.0):
        super().__init__()
        self.tanh = tanh
        self.coords_range = coords_range
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        # Final bias-free projection, initialised with a very small gain so the
        # initial coordinate updates are close to zero.
        coord_head = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(coord_head.weight, gain=0.001)

        edge_in_features = hidden_nf * 2 + edges_in_d
        self.coord_mlp = nn.Sequential(
            nn.Linear(edge_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            coord_head)

    def coord_model(self, h, coord, edge_index, coord_diff, edge_attr, edge_mask):
        """Return ``coord`` plus the aggregated per-edge translations."""
        row, col = edge_index
        edge_input = torch.cat([h[row], h[col], edge_attr], dim=1)
        scale = self.coord_mlp(edge_input)
        if self.tanh:
            trans = coord_diff * torch.tanh(scale) * self.coords_range
        else:
            trans = coord_diff * scale
        if edge_mask is not None:
            trans = trans * edge_mask
        shift = unsorted_segment_sum(trans, row, num_segments=coord.size(0),
                                     normalization_factor=self.normalization_factor,
                                     aggregation_method=self.aggregation_method)
        return coord + shift

    def forward(self, h, coord, edge_index, coord_diff, edge_attr=None, node_mask=None, edge_mask=None):
        """Apply ``coord_model`` and then ``node_mask`` (if given) to the new coordinates."""
        coord = self.coord_model(h, coord, edge_index, coord_diff, edge_attr, edge_mask)
        return coord if node_mask is None else coord * node_mask


class EquivariantBlock(nn.Module):
    """``n_layers`` feature-updating ``GCL`` layers followed by one ``EquivariantUpdate``.

    Args:
        hidden_nf: Node feature width used by every sub-layer.
        edge_feat_nf: Width of the edge attributes fed to the sub-layers
            (distances concatenated with the incoming ``edge_attr``).
        device: Device the module is moved to at the end of construction.
        act_fn: Activation module forwarded to the ``GCL`` layers.
        n_layers: Number of ``GCL`` layers in the block.
        attention: Forwarded to the ``GCL`` layers.
        norm_diff: Stored on the instance; not read by this class.
        tanh: Forwarded to ``EquivariantUpdate``.
        coords_range: Forwarded (as float) to ``EquivariantUpdate``.
        norm_constant: Forwarded to ``coord2diff``.
        sin_embedding: Optional callable applied to the distances.
        normalization_factor: Forwarded to all sub-layers.
        aggregation_method: Forwarded to all sub-layers.
    """

    def __init__(self, hidden_nf, edge_feat_nf=2, device='cpu', act_fn=nn.SiLU(), n_layers=2, attention=True,
                 norm_diff=True, tanh=False, coords_range=15, norm_constant=1, sin_embedding=None,
                 normalization_factor=100, aggregation_method='sum'):
        super().__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.norm_constant = norm_constant
        self.sin_embedding = sin_embedding
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        # Feature layers are registered as gcl_0 ... gcl_{n_layers-1}.
        for i in range(n_layers):
            feature_layer = GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=edge_feat_nf,
                                act_fn=act_fn, attention=attention,
                                normalization_factor=self.normalization_factor,
                                aggregation_method=self.aggregation_method)
            self.add_module("gcl_%d" % i, feature_layer)
        # The coordinate layer always uses its own SiLU instance.
        coord_layer = EquivariantUpdate(hidden_nf, edges_in_d=edge_feat_nf, act_fn=nn.SiLU(), tanh=tanh,
                                        coords_range=self.coords_range_layer,
                                        normalization_factor=self.normalization_factor,
                                        aggregation_method=self.aggregation_method)
        self.add_module("gcl_equiv", coord_layer)
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None, edge_attr=None):
        """Update features ``h`` and coordinates ``x``; returns ``(h, x)``."""
        distances, coord_diff = coord2diff(x, edge_index, self.norm_constant)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        # Current distances are prepended to the edge attributes passed in.
        edge_attr = torch.cat([distances, edge_attr], dim=1)
        for i in range(self.n_layers):
            feature_layer = getattr(self, "gcl_%d" % i)
            h, _ = feature_layer(h, edge_index, edge_attr=edge_attr, node_mask=node_mask, edge_mask=edge_mask)
        x = getattr(self, "gcl_equiv")(h, x, edge_index, coord_diff, edge_attr, node_mask, edge_mask)

        # Important, the bias of the last linear might be non-zero
        if node_mask is not None:
            h = h * node_mask
        return h, x


class EGNN(nn.Module):
    """Stack of ``EquivariantBlock`` modules between input and output linear embeddings.

    Args:
        in_node_nf: Width of the input node features.
        in_edge_nf: Accepted for interface compatibility; not read by this class.
        hidden_nf: Hidden feature width.
        device: Device the module is moved to at the end of construction.
        act_fn: Activation module forwarded to the blocks.
        n_layers: Number of ``EquivariantBlock`` modules.
        attention: Forwarded to the blocks.
        norm_diff: Forwarded to the blocks.
        out_node_nf: Output width; ``None`` means "same as ``in_node_nf``".
        tanh: Forwarded to the blocks.
        coords_range: Forwarded unchanged to the blocks; the per-layer value
            ``coords_range / n_layers`` is only stored on this instance.
        norm_constant: Forwarded to the blocks.
        inv_sublayers: Number of ``GCL`` layers inside each block.
        sin_embedding: When true, distances are passed through
            ``SinusoidsEmbeddingNew``.
        normalization_factor: Forwarded to the blocks.
        aggregation_method: Forwarded to the blocks.
    """

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=3, attention=False,
                 norm_diff=True, out_node_nf=None, tanh=False, coords_range=15, norm_constant=1, inv_sublayers=2,
                 sin_embedding=False, normalization_factor=100, aggregation_method='sum'):
        super().__init__()
        out_width = in_node_nf if out_node_nf is None else out_node_nf
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        if n_layers > 0:
            self.coords_range_layer = float(coords_range / n_layers)
        else:
            self.coords_range_layer = float(coords_range)
        self.norm_diff = norm_diff
        self.normalization_factor = normalization_factor
        self.aggregation_method = aggregation_method

        # NOTE(review): SinusoidsEmbeddingNew is not defined in the visible part
        # of this notebook; sin_embedding=True needs it to be in scope - confirm.
        self.sin_embedding = SinusoidsEmbeddingNew() if sin_embedding else None
        # Edge features per block: the distance (or its embedding) appears twice,
        # once from this module and once recomputed inside the block.
        edge_feat_nf = 2 if self.sin_embedding is None else self.sin_embedding.dim * 2

        self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_out = nn.Linear(self.hidden_nf, out_width)

        block_kwargs = dict(
            edge_feat_nf=edge_feat_nf, device=device,
            act_fn=act_fn, n_layers=inv_sublayers,
            attention=attention, norm_diff=norm_diff, tanh=tanh,
            coords_range=coords_range, norm_constant=norm_constant,
            sin_embedding=self.sin_embedding,
            normalization_factor=self.normalization_factor,
            aggregation_method=self.aggregation_method)
        # Blocks are registered as e_block_0 ... e_block_{n_layers-1}.
        for i in range(n_layers):
            self.add_module("e_block_%d" % i, EquivariantBlock(hidden_nf, **block_kwargs))
        self.to(self.device)

    def forward(self, h, x, edge_index, node_mask=None, edge_mask=None):
        """Return ``(h, x)`` after all blocks; ``h`` is projected and masked at the end."""
        # Distances of the *input* coordinates are reused as edge attributes by
        # every block.
        distances, _ = coord2diff(x, edge_index)
        if self.sin_embedding is not None:
            distances = self.sin_embedding(distances)
        h = self.embedding(h)
        for i in range(self.n_layers):
            block = getattr(self, "e_block_%d" % i)
            h, x = block(h, x, edge_index, node_mask=node_mask, edge_mask=edge_mask, edge_attr=distances)

        # Important, the bias of the last linear might be non-zero
        h = self.embedding_out(h)
        if node_mask is not None:
            h = h * node_mask
        return h, x


In [5]:


def fully_connected_graph_with_self_loops(num_nodes):
    """
    Generates edge indices for a fully connected graph with self-loops.

    Args:
        num_nodes (int): Number of nodes in the graph.

    Returns:
        torch.Tensor: Long tensor of shape [2, num_nodes ** 2]; column k is the
        (source, target) pair of edge k, ordered by source and then by target.
        For ``num_nodes == 0`` the result has shape [2, 0].
    """
    pairs = [[i, j] for i in range(num_nodes) for j in range(num_nodes)]

    # An explicit dtype plus reshape(-1, 2) keeps the result a [2, E] long tensor
    # even when there are no pairs; an empty list would otherwise produce a 1-D
    # float tensor that .t() leaves unchanged.
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)

    return edge_index.t().contiguous()


Define the edge type prediction model

In [6]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, MessagePassing
from torch_geometric.data import Data

class BondTypePredictor(nn.Module):
    """Two GCN layers that encode the nodes, followed by an MLP that classifies every edge."""

    def __init__(self, num_node_features, hidden_dim, num_classes):
        """
        Build the node encoder and the edge classification head.

        Args:
            num_node_features (int): Number of input features for each node.
            hidden_dim (int): Width of the GCN layers and of the hidden MLP layer.
            num_classes (int): Number of bond types to predict.
        """
        super().__init__()

        # Node encoder.
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Edge head: takes the concatenated embeddings of both endpoints.
        head_layers = [
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.edge_classifier = nn.Sequential(*head_layers)

    def forward(self, x, edge_index):
        """
        Predict a bond-type distribution (as logits) for every listed edge.

        Args:
            x (torch.Tensor): Node feature matrix of shape [num_nodes, num_node_features].
            edge_index (torch.Tensor): Edge index matrix of shape [2, num_edges].

        Returns:
            torch.Tensor: Logits of shape [num_edges, num_classes], one row per
            column of ``edge_index``.
        """
        for conv in (self.conv1, self.conv2):
            x = torch.relu(conv(x, edge_index))

        # Edge representation = [embedding of first endpoint | embedding of second].
        first, second = edge_index
        endpoints = torch.cat([x[first], x[second]], dim=1)

        return self.edge_classifier(endpoints)


In [8]:
import torch
import torch.nn as nn
from torch.optim import Adam
# Parameters
# node_feature_dim matches the width of the one-hot vectors in encoding_mappings.
node_feature_dim = 5
hidden_dim = 32
# Five classes: the bond codes 0-4 that smiles_to_graph can emit.
num_bond_types = 5

# Initialize model
model = BondTypePredictor(node_feature_dim, hidden_dim, num_bond_types)

# Loss function and optimizer
# CrossEntropyLoss is applied to the per-edge logits and the integer bond codes.
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)


In [9]:
def train(model, data_loader, criterion, optimizer):
    """Run one optimisation pass over ``data_loader`` and return the mean loss.

    Args:
        model: Module called as ``model(x, edge_index)``; put into training mode.
        data_loader: Sized iterable of items exposing ``x``, ``edge_index`` and
            ``edge_attr``. Only the first five columns of ``x`` are fed to the
            model.
        criterion: Loss callable applied to ``(logits, item.edge_attr)``.
        optimizer: Optimiser stepped once per item.

    Returns:
        The sum of the per-item losses divided by ``len(data_loader)``.
    """
    def run_step(batch):
        # One forward/backward/update cycle; returns the loss as a Python float.
        optimizer.zero_grad()
        logits = model(batch.x[:, :5], batch.edge_index)
        step_loss = criterion(logits, batch.edge_attr)
        step_loss.backward()
        optimizer.step()
        return step_loss.item()

    model.train()  # Set the model to training mode
    accumulated = 0
    for batch in data_loader:
        accumulated += run_step(batch)

    return accumulated / len(data_loader)


In [None]:

# Training parameters
num_epochs =20

# Training loop
# filtered_dataset (a plain list of Data graphs) is passed as the data loader,
# so every optimisation step uses exactly one molecule.
for epoch in range(num_epochs):
    # `loss` is the value returned by train(): the per-graph losses averaged
    # over the whole list for this epoch.
    loss = train(model, filtered_dataset, criterion, optimizer)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss:.4f}")


In [11]:
from torch_geometric.data import Data

# Convert SMILES to graph data
def prepare_graph(smiles):
    """Wrap the graph of a SMILES string in a ``Data`` object.

    Args:
        smiles: SMILES string handed to ``smiles_to_graph``.

    Returns:
        ``Data(x=node_features, edge_index=edge_index)``, or ``None`` when
        ``smiles_to_graph`` reports an unparseable string. The bond-type
        labels are not attached.
    """
    # NOTE(review): x holds the raw node features from smiles_to_graph, while the
    # training graphs above use 5-wide one-hot rows - confirm this is intended.
    graph_parts = smiles_to_graph(smiles)
    atoms, bonds = graph_parts[0], graph_parts[1]
    if atoms is None:
        return None
    return Data(x=atoms, edge_index=bonds)


In [None]:
import pickle
# Serialise the trained model object to edge_type_model.pkl in the working
# directory (the f-string prefix has no placeholders, so the name is fixed).
# NOTE(review): this pickles the whole module object rather than its
# state_dict; presumably loading needs the BondTypePredictor class to be
# importable - confirm with whoever consumes the file.
with open(f'edge_type_model.pkl', 'wb') as f:
    pickle.dump(model, f)