In [17]:
import os
import logging
from dotenv import load_dotenv
from typing import Union
from dataclasses import dataclass

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from rdkit import Chem, DataStructs
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import rdmolops, Draw
import sys

import torch
from torch import nn
from torchinfo import summary
from tqdm.notebook import tqdm


Molecule = Union[str, Chem.Mol]

logging.getLogger('rdkit').setLevel(logging.WARNING)

load_dotenv()
PROJECT_ROOT = os.getenv("PROJECT_ROOT") if os.getenv("PROJECT_ROOT") else "/home/elisey/dkr/F24-DKR-Project"
DATA_DIR = PROJECT_ROOT + "/data/"

sys.path.append(PROJECT_ROOT + "/graformer")
from graformer import GraphTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

## Open data

In [18]:
def is_toxic(row): 
    return int(any(row[1:])) 

data = pd.read_csv(DATA_DIR + "tox21.csv")
data.replace(np.nan, 0, inplace=True)
data['is_toxic'] = data.apply(is_toxic, axis=1)
data = data[['smiles', 'is_toxic']]

data

Unnamed: 0,smiles,is_toxic
0,CCOc1ccc2nc(S(N)(=O)=O)sc2c1,1
1,CCN1C(=O)NC(c2ccccc2)C1=O,0
2,CC[C@]1(O)CC[C@H]2[C@@H]3CCC4=CCCC[C@@H]4[C@H]...,0
3,CCCN(CC)C(CC)C(=O)Nc1c(C)cccc1C,0
4,CC(O)(P(=O)(O)O)P(=O)(O)O,0
...,...,...
7826,CCOc1nc2cccc(C(=O)O)c2n1Cc1ccc(-c2ccccc2-c2nnn...,0
7827,CC(=O)[C@H]1CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@]4(...,1
7828,C[C@]12CC[C@H]3[C@@H](CCC4=CC(=O)CC[C@@]43C)[C...,1
7829,C[C@]12CC[C@@H]3c4ccc(O)cc4CC[C@H]3[C@@H]1CC[C...,1


## Define feature extractor

In [19]:
@dataclass
class Embedding: 
    """
    A class used to represent the Embedding of a graph.

    node_embeddings : np.ndarray
        An array representing the embeddings of the nodes in the graph.
    adjacency : np.ndarray
        An adjacency matrix representing the connections between nodes in the graph.
    degree : np.ndarray
        An array representing the degree of each node in the graph.
    laplacian : np.ndarray
        A Laplacian matrix derived from the adjacency matrix of the graph.
    edge_index : np.ndarray
        An array representing the indices of the edges in the graph.
    """
    node_embeddings: np.ndarray
    adjacency: np.ndarray
    degree: np.ndarray
    laplacian: np.ndarray
    edge_index: np.ndarray  
    
    def __str__(self):
        res = f"Node embeddings: {self.node_embeddings.shape}\n"
        res += f"Adjacency matrix: {self.adjacency.shape}\n"
        res += f"Degree matrix: {self.degree.shape}\n"
        res += f"Laplacian matrix: {self.laplacian.shape}\n"
        res += f"Edge index: {self.edge_index.shape}"
        
        return res 

class FeatureExtractor:
    MORGAN_RADIUS = 2
    MORGAN_NUM_BITS = 2048
    
    def __init__(self):
        self.mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2,fpSize=2048)
        
    
    def morgan_features_generator(self, mol: Molecule, radius: int = MORGAN_RADIUS, num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
        """
        Generates a Morgan fingerprint for a molecule.

        :param mol: A molecule (i.e., either a SMILES or an RDKit molecule).
        :param radius: Morgan fingerprint radius.
        :param num_bits: Number of bits in Morgan fingerprint.
        :return: A 2D numpy array containing the Morgan fingerprint for each atom in the molecule.
        """
        mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        features = np.zeros((mol.GetNumAtoms(), num_bits))

        for atom in range(mol.GetNumAtoms()):
            env = Chem.FindAtomEnvironmentOfRadiusN(mol, radius, atom)
            amap = {}
            submol = Chem.PathToSubmol(mol, env, atomMap=amap)
            Chem.GetSSSR(submol)
            # features_vec = AllChem.GetMorganFingerprintAsBitVect(submol, radius, nBits=num_bits)
            features_vec = self.mfpgen.GetFingerprint(submol)
            DataStructs.ConvertToNumpyArray(features_vec, features[atom])

        # (n_atoms, embedding_size)
        return features
    
    def mol_to_graph(self, molecule: Chem.Mol) -> np.ndarray:
        """
        Converts a molecule to its graph representation.
        
        :param molecule: (Chem.Mol or str): The molecule to convert. It can be either a RDKit Mol object or a SMILES string.
        :return np.ndarray: The adjacency matrix representing the graph of the molecule.
        """
        
        mol = Chem.MolFromSmiles(molecule) if isinstance(molecule, str) else molecule
        graph = rdmolops.GetAdjacencyMatrix(mol)
        
        # (n_atoms, n_atoms)
        return graph
    
    def molecule_show(selff, molecule, title=False):
        """
        Displays a visual representation of a molecule from its SMILES string.
        
        :param molecule: (str): The SMILES string representation of the molecule.
        :param title: (bool, optional): If True, displays the SMILES string as the title of the plot. Defaults to False.
        """
        
        m = Chem.MolFromSmiles(molecule)
        img = Draw.MolToImage(m)
        
        if title: 
            plt.title(molecule)
    
        plt.imshow(img)
    
    def __call__(self, molecule: Molecule): 
        """
        Makes an embeddings of a molecule. 
        
        :param molecule: (str): The SMILES string representation of the molecule.
        :return Embedding: The embeddings of the molecule.
        """
        
        node_embeddings = self.morgan_features_generator(molecule, self.MORGAN_RADIUS, self.MORGAN_NUM_BITS)
        adjacency = self.mol_to_graph(molecule)
        degree = np.diag(np.sum(adjacency, axis=1))
        laplacian = degree - adjacency
        edge_index = np.array(np.nonzero(adjacency))

        return Embedding(node_embeddings, adjacency, degree, laplacian, edge_index)

feature_extractor = FeatureExtractor()

## Make torch dataset

In [20]:
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

class MoleculeDataset(Dataset):
    def __init__(self, data: pd.DataFrame, feature_extractor: FeatureExtractor):
        """
        Initializes the MoleculeDataset.

        :param data: A pandas DataFrame containing the dataset with SMILES strings.
        :param feature_extractor: An instance of FeatureExtractor to generate features.
        """
        self.data = data
        self.feature_extractor = feature_extractor

    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Generates one sample of data.

        :param idx: The index of the sample to retrieve.
        :return: A tuple containing the node embeddings and adjacency matrix of the molecule.
        """
        smiles = self.data.iloc[idx]['smiles']
        embedding = self.feature_extractor(smiles)
        label = self.data.iloc[idx]['is_toxic']
        
        return (
            torch.tensor(embedding.node_embeddings, dtype=torch.float32).unsqueeze(0),
            torch.tensor(embedding.adjacency, dtype=torch.float32).unsqueeze(0),
            torch.tensor(label, dtype=torch.float32)
        )
# Example usage
train_data, test_data = train_test_split(data, test_size=0.1, random_state=42, shuffle=True)
train_dataset = MoleculeDataset(train_data, feature_extractor)
test_dataset = MoleculeDataset(test_data, feature_extractor)

## Define classifier task

In [21]:
class TransformerClassifier(nn.Module):
    def __init__(self, dim, depth):
        super(TransformerClassifier, self).__init__()
        self.transformer = GraphTransformer(
            dim=dim, depth=depth, accept_adjacency_matrix=True
        )
        self.classifier = nn.Sequential(
            nn.Linear(2048, 1024), 
            nn.GELU(), 
            nn.Linear(1024, 1)
        )
    
    def forward(self, node_embeddings, adjacency_matrix):
        # node_embeddings: [batch_size, max_atoms, 2048]
        # adjacency_matrix: [batch_size, max_atoms, max_atoms]
        transformer_output = self.transformer(node_embeddings, adj_mat=adjacency_matrix)
        # Assuming transformer_output is a tuple where the first element is the node embeddings
        node_embeddings_transformed = transformer_output[0]
        # Aggregate node embeddings, e.g., take mean over nodes
        graph_embedding = node_embeddings_transformed.mean(dim=1)
        # graph_embedding: [batch_size, 2048]
        return self.classifier(graph_embedding)

In [22]:
model = TransformerClassifier(dim=2048, depth=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.BCEWithLogitsLoss()

summary(model)

Layer (type:depth-idx)                             Param #
TransformerClassifier                              --
├─GraphTransformer: 1-1                            --
│    └─ModuleList: 2-1                             --
│    │    └─ModuleList: 3-1                        5,257,216
│    └─Identity: 2-2                               --
│    └─Embedding: 2-3                              4,096
├─Sequential: 1-2                                  --
│    └─Linear: 2-4                                 2,098,176
│    └─GELU: 2-5                                   --
│    └─Linear: 2-6                                 1,025
Total params: 7,360,513
Trainable params: 7,360,513
Non-trainable params: 0

In [23]:
def custom_collate(batch):
    max_atoms = max(node_embeddings.shape[1] for node_embeddings, _, _ in batch)
    batch_size = len(batch)
    
    # Initialize padded tensors
    padded_node_embeddings = torch.zeros(batch_size, max_atoms, 2048)
    padded_adjacency_matrices = torch.zeros(batch_size, max_atoms, max_atoms)
    labels = torch.zeros(batch_size)
    
    for i, (node_embeddings, adjacency_matrix, label) in enumerate(batch):
        n_atoms = node_embeddings.shape[1]
        padded_node_embeddings[i, :n_atoms, :] = node_embeddings.squeeze(0)
        padded_adjacency_matrices[i, :n_atoms, :n_atoms] = adjacency_matrix.squeeze(0)
        labels[i] = label
    
    return padded_node_embeddings, padded_adjacency_matrices, labels

In [24]:
from tqdm import tqdm
from torch.utils.data import DataLoader

def train(model, dataset, optimizer, criterion, epochs=10, device='cpu'):
    model = model.to(device)
    model.train()
    
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for node_embeddings, adjacency_matrix, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            node_embeddings = node_embeddings.to(device)
            adjacency_matrix = adjacency_matrix.to(device)
            labels = labels.to(device)
            
            outputs = model(node_embeddings, adjacency_matrix)
            loss = criterion(outputs.squeeze(), labels)
            
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        mean_epoch_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_epoch_loss:.4f}")
        
# Example usage
train(model, train_dataset, optimizer, criterion, epochs=1, device=device)

Epoch 1/1: 100%|██████████| 221/221 [08:11<00:00,  2.23s/it]

Epoch 1/1, Loss: 0.6362





In [None]:
def validate(model, dataset, criterion, device='cpu'):
    model = model.to(device)
    model.eval()
    
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=custom_collate)
    
    total_loss = 0
    correct_predictions = 0
    total_samples = 0
    
    with torch.no_grad():
        for node_embeddings, adjacency_matrix, labels in tqdm(dataloader, desc="Validating"):
            node_embeddings = node_embeddings.to(device)
            adjacency_matrix = adjacency_matrix.to(device)
            labels = labels.to(device)
            
            outputs = model(node_embeddings, adjacency_matrix)
            loss = criterion(outputs.squeeze(), labels)
            total_loss += loss.item()
            
            predictions = torch.round(torch.sigmoid(outputs))
            correct_predictions += (predictions.cpu() == labels.cpu()).sum().item()
            total_samples += labels.size(0)
        
    mean_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    print(f"Validation Loss: {mean_loss:.4f}, Accuracy: {accuracy:.4f}")

# Example usage
validate(model, test_dataset, criterion, device=device)

Validating: 100%|██████████| 25/25 [00:23<00:00,  1.06it/s]

Validation Loss: 0.6058, Accuracy: 18.4362





# Inference

In [26]:
def predict(model, feature_extractor, smiles, device='cpu'):
    """
    Predicts the toxicity of a molecule given its SMILES string.

    :param model: The trained model.
    :param feature_extractor: The feature extractor used to generate embeddings.
    :param smiles: The SMILES string of the molecule.
    :param device: The device to run the model on (e.g., 'cpu' or 'cuda').
    :return: The predicted toxicity (0 or 1).
    """
    model = model.to(device)
    model.eval()

    # Generate embeddings for the molecule
    embedding = feature_extractor(smiles)
    node_embeddings = torch.tensor(embedding.node_embeddings, dtype=torch.float32).unsqueeze(0).to(device)
    adjacency_matrix = torch.tensor(embedding.adjacency, dtype=torch.float32).unsqueeze(0).to(device)

    # Make prediction
    with torch.no_grad():
        outputs = model(node_embeddings, adjacency_matrix)
        prediction = torch.round(torch.sigmoid(outputs)).item()

    return int(prediction)

# Example usage
new_smiles = "CCO"  # Replace with your SMILES string
prediction = predict(model, feature_extractor, new_smiles, device=device)
print(f"The predicted toxicity of the molecule with SMILES {new_smiles} is: {prediction}")

The predicted toxicity of the molecule with SMILES CCO is: 1
