### GraphAlchemist: A Graph-Based Autoregressive VAE for Molecule Generation

* **Author:** Hosein Mohammadi
* **Date:** July 2024
* **Contact:** [huseinmohammadi83@gmail.com](mailto:huseinmohammadi83@gmail.com)
* **LinkedIn:** [Hosein Mohammadi](https://www.linkedin.com/in/hosein-mohammadi-979b8a2b2/)
* **Project Repository:** [GNN-Molecule-Generator](https://github.com/Hosein541/GNN-Molecule-Generator)
---

### 📖 Project Overview

This project explores *de novo* molecular design with deep learning. It uses a **graph-based representation** of chemical structures and implements an **Autoregressive Variational Autoencoder (VAE)** built with Graph Neural Networks (GNNs) to generate new, chemically valid compounds.

The model is trained on a dataset of molecules targeting the **EGFR protein**, sourced from the ChEMBL database. The ultimate goal is to learn the underlying patterns of this chemical space and generate novel molecules with desirable drug-like properties, evaluated using metrics like **QED (Quantitative Estimation of Drug-likeness)** and **SA Score (Synthetic Accessibility Score)**.

### Key Features
* **Graph-Based Representation:** Treats molecules as graphs, capturing rich structural information beyond simple string-based methods.
* **GAT Encoder:** A Graph Attention Network (GAT) encodes molecules into a continuous latent space.
* **Autoregressive Decoder:** The decoder generates molecules step by step (atom by atom, bond by bond), so each decision is conditioned on the partially built structure.
* **Constrained Generation:** The generation script checks every proposed bond against per-element valence limits, which is intended to raise the fraction of chemically valid outputs.
* **Tech Stack:** Built with **PyTorch**, **PyTorch Geometric**, and the **RDKit** cheminformatics toolkit.

### Project Workflow
1.  **Data Collection:** Fetching and preprocessing data for the EGFR target from ChEMBL.
2.  **Graph Conversion:** Transforming SMILES strings into graph data objects.
3.  **Model Training:** Training the Autoregressive VAE using a checkpointing system.
4.  **Molecule Generation:** Using the trained model to generate thousands of novel molecules.
5.  **Analysis & Visualization:** Scoring, selecting, and visualizing the best candidates in 3D.

In [None]:
# ===================================================================
# Section 1: Library Installation
# ===================================================================

# For accessing the ChEMBL database
!pip install -q chembl_webresource-client

# Core chemoinformatics and GNN libraries
# (the PyPI package is published as `rdkit`; `rdkit-pypi` is the older, deprecated name)
!pip install -q rdkit
!pip install -q torch_geometric

# For 3D visualization
!pip install -q py3Dmol

# PyTorch is usually pre-installed in Colab, but this ensures it's there
!pip install -q torch

# ===================================================================
# Section 2: Library Imports
# ===================================================================

# --- Standard Python Libraries ---
import os
import sys
from pathlib import Path

# --- Core Data Science & Plotting Libraries ---
import time
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# --- Chemoinformatics Libraries (RDKit & ChEMBL) ---
from chembl_webresource_client.new_client import new_client
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, QED, RDConfig
# Add SA_Score module to Python path for synthesis accessibility calculation
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

# --- Deep Learning Libraries (PyTorch & PyTorch Geometric) ---
import torch
from torch import nn, optim
import torch.nn.functional as F
# The DataLoader for graph data comes from PyTorch Geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch.nn import Sequential, Linear, ReLU, Module, ModuleList
from torch_geometric.nn import GATConv, global_add_pool, global_mean_pool



print("✅ All required libraries imported successfully.")

### 1. Data Collection & Preprocessing

This initial step focuses on acquiring and cleaning the data needed to train our generative model. We source our data from the **ChEMBL database**, a large, open-access database of bioactive drug-like molecules.

#### Data Fetching
The process begins by targeting a specific protein, in this case, the **Epidermal Growth Factor Receptor (EGFR)**, which is identified by its ChEMBL ID `CHEMBL203`. We then fetch all associated bioactivity data, specifically looking for records that measure the **IC50** value. The IC50 is a common measure of a drug's potency, indicating the concentration of a drug required to inhibit a biological process by 50%.

For each record, we retrieve the molecule's structure in **SMILES** format and its corresponding IC50 value.

#### Data Cleaning & Transformation
The raw data requires several preprocessing steps to be suitable for machine learning:
* **Filtering:** We keep only the records with the "IC50" standard type to ensure consistency in our activity measurements.
* **Cleaning:** Rows with missing SMILES strings or IC50 values are removed to ensure data quality.
* **Calculating pIC50:** IC50 values can span many orders of magnitude, so we move to a logarithmic scale to make them more tractable for the model. The code treats `standard_value` as nM, converts it to molar by multiplying by $10^{-9}$, and computes **pIC50** as:
    $$
    \text{pIC50} = -\log_{10}(\text{IC50}_{\text{Molar}})
    $$
    A higher pIC50 value corresponds to a more potent molecule.

Finally, the cleaned and processed dataset, containing the molecule's SMILES string and its pIC50 value, is saved to `egfr_ic50_dataset.csv`. This file serves as the primary input for the rest of the project.
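
As a quick sanity check of the conversion described above, here is a minimal pure-Python sketch. The helper name `pic50_from_nm` is hypothetical and not part of the notebook; it assumes the input is already in nM, as the preprocessing code does.

```python
import math


def pic50_from_nm(ic50_nm: float) -> float:
    """Convert an IC50 given in nanomolar to pIC50 (assumes nM input)."""
    if ic50_nm <= 0:
        raise ValueError("IC50 must be positive to take a logarithm")
    ic50_molar = ic50_nm * 1e-9  # nM -> M
    return -math.log10(ic50_molar)


# 1000 nM is 1e-6 M, so its pIC50 is 6; a tenfold more potent compound gains one unit.
print(round(pic50_from_nm(1000.0), 6))
print(round(pic50_from_nm(100.0), 6))
```

Because the scale is logarithmic, a difference of one pIC50 unit corresponds to a tenfold difference in IC50.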

In [None]:
# --- Settings ---
target_id = "CHEMBL203"
max_records = 30000

activity = new_client.activity

print("⏳ Fetching data from ChEMBL...")

# EGFR data generator
activity_gen = activity.filter(
    target_chembl_id=target_id
).only(
    [
        "canonical_smiles",
        "standard_type",
        "standard_value",
        "molecule_chembl_id",
        "activity_id",
    ]
)

# Manually retrieve data up to max_records
records = []
for i, record in enumerate(activity_gen):
    if i >= max_records:
        break
    records.append(record)
    if i % 500 == 0:
        print(f"🔄 Received record {i}")

print(f"🎯 Number of records received: {len(records)}")

# --- Convert to DataFrame and Preprocessing ---
df = pd.DataFrame(records)

# Filter IC50 and numerical values
df = df[df["standard_type"] == "IC50"]
df = df[df["standard_value"].notna()]
df = df[df["canonical_smiles"].notna()]
df["standard_value"] = pd.to_numeric(df["standard_value"], errors="coerce")
df = df[df["standard_value"] > 0]

# Calculate pIC50
df["pIC50"] = -np.log10(df["standard_value"] * 1e-9)

# Final output
final_df = df[["molecule_chembl_id", "canonical_smiles", "standard_value", "pIC50"]]
final_df.columns = ["chembl_id", "smiles", "IC50_nM", "pIC50"]
final_df.to_csv("egfr_ic50_dataset.csv", index=False)

print("✅ File egfr_ic50_dataset.csv saved.")
print(f"📦 Final records: {final_df.shape[0]}")

In [None]:
# Load the dataset
df = pd.read_csv("egfr_ic50_dataset.csv")

# We only use the SMILES column for reconstruction
smiles_list = df['smiles'].tolist()
print(f"Total SMILES count: {len(smiles_list)}")

### 2. Graph Representation: Converting SMILES to Graphs

This code block is the cornerstone of our graph-based approach. Its primary purpose is to convert the entire dataset of **SMILES** strings into a format that a Graph Neural Network (GNN) can understand: a list of individual **graph** objects. This process is handled by the `smiles_to_graph` function.

#### The Conversion Process

For each molecule, the function constructs a graph where atoms are **nodes** and bonds are **edges**. Hydrogens are made explicit with `Chem.AddHs`, so they appear as nodes as well. To provide the GNN with chemical context, we extract a feature vector for every node and edge.

* **Nodes (Atoms):** Each atom in the molecule becomes a node in the graph. Its feature vector contains the following chemical properties:
    * Atomic Number
    * Hybridization Type (e.g., sp, sp2, sp3)
    * Total Number of Attached Hydrogens
    * Formal Charge
    * Aromaticity Status (True/False)
    * Degree (number of directly bonded neighbor atoms)

* **Edges (Bonds):** Each chemical bond between two atoms becomes an edge connecting the corresponding nodes, stored once per direction. Its five-entry feature vector contains:
    * Bond Type, one-hot encoded over four classes (Single, Double, Triple, Aromatic)
    * A binary flag indicating whether the bond is part of a ring structure

#### Processing the Entire Dataset

The main script iterates through every valid SMILES string from our cleaned dataset, applies this conversion function, and stores each resulting graph object in a list. A progress bar from `tqdm` visualizes this process.

The final output is `egfr_graph_dataset.pt`, a single file containing the list of all processed graph objects. Saving the data in this serialized format allows for extremely fast loading in the subsequent model training steps.
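
To make the edge layout concrete, the following pure-Python sketch mirrors how each bond is stored twice (once per direction) together with a five-entry feature vector. The `Bond` tuple and the `build_edges` helper are illustrative assumptions, not code from the notebook, which reads the same information from RDKit bond objects.

```python
from typing import List, NamedTuple, Tuple

BOND_TYPES = ("SINGLE", "DOUBLE", "TRIPLE", "AROMATIC")


class Bond(NamedTuple):
    """Hypothetical stand-in for an RDKit bond."""
    begin: int
    end: int
    bond_type: str
    in_ring: bool


def build_edges(bonds: List[Bond]) -> Tuple[List[List[int]], List[List[float]]]:
    """Return (edge_index, edge_attr) with every bond listed in both directions."""
    sources, targets, edge_attr = [], [], []
    for bond in bonds:
        # Four one-hot entries for the bond type, then one ring-membership flag.
        features = [float(bond.bond_type == name) for name in BOND_TYPES]
        features.append(float(bond.in_ring))
        for u, v in ((bond.begin, bond.end), (bond.end, bond.begin)):
            sources.append(u)
            targets.append(v)
            edge_attr.append(features)
    return [sources, targets], edge_attr


# Heavy-atom skeleton of acetaldehyde (C-C=O): two bonds become four directed edges.
edge_index, edge_attr = build_edges([
    Bond(0, 1, "SINGLE", False),
    Bond(1, 2, "DOUBLE", False),
])
print(edge_index)
print(edge_attr[2])
```

Storing both directions is what lets a message-passing layer send information along a bond in either direction.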

In [None]:
# Define a function to convert a SMILES string to a graph
def smiles_to_graph(smiles):
    """Converts a SMILES string to a PyTorch Geometric graph object."""
    mol = Chem.MolFromSmiles(smiles)

    # --- Validation step ---
    if mol is None:
        return None # Return None if the molecule is invalid

    mol = Chem.AddHs(mol)

    # Extract node features (atoms)
    node_features = []
    for atom in mol.GetAtoms():
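        features = [
            atom.GetAtomicNum(),
            int(atom.GetHybridization()),  # RDKit enum -> integer code
            # Hydrogens are explicit atoms after AddHs, so count neighboring H atoms too
            atom.GetTotalNumHs(includeNeighbors=True),
            atom.GetFormalCharge(),
            int(atom.GetIsAromatic()),
            atom.GetDegree(),
        ]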
        node_features.append(features)
    x = torch.tensor(node_features, dtype=torch.float)

    # Extract edges and edge features (bonds)
    if mol.GetNumBonds() == 0: # Molecule with no bonds (only one atom)
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 5), dtype=torch.float)
    else:
        edge_indices, edge_features_list = [], []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices.extend([(i, j), (j, i)])

            bond_type = [
                bond.GetBondType() == Chem.rdchem.BondType.SINGLE,
                bond.GetBondType() == Chem.rdchem.BondType.DOUBLE,
                bond.GetBondType() == Chem.rdchem.BondType.TRIPLE,
                bond.GetBondType() == Chem.rdchem.BondType.AROMATIC,
                bond.IsInRing(),
            ]
            edge_features_list.extend([bond_type, bond_type])

        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features_list, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)

# Load the dataset
df = pd.read_csv('egfr_ic50_dataset.csv')

# Remove rows without SMILES strings
df.dropna(subset=['smiles'], inplace=True)


# Create a list of graphs using the function and display a progress bar
graph_list = []
for smiles in tqdm(df['smiles'], desc="Processing Molecules"):
    try:
        data = smiles_to_graph(smiles)
        if data is not None:
            graph_list.append(data)
    except Exception as e:
        print(f"Error processing SMILES {smiles}: {e}")


print("\n--- Processing complete ---")
print(f"Total molecules in the DataFrame: {len(df)}")
print(f"Number of valid molecules converted to graphs: {len(graph_list)}")

# Now 'graph_list' is our ready dataset for the model.
# You can save this list for later use.
torch.save(graph_list, 'egfr_graph_dataset.pt')

### 3. Model Architecture & Training

This section contains the core of our project: the definition of our advanced generative model and the complex loop used to train it.

#### Model Architecture 🧠

The model is a **Variational Autoencoder (VAE)** specifically designed for graph data. It consists of two main components:

1.  **Encoder (`GraphVAE_GAT`):**
    This network reads a complete molecular graph and compresses it into a low-dimensional vector in a "latent space." It stacks three **Graph Attention (GAT)** layers, which let each atom weigh the contributions of its neighboring atoms; the node outputs are mean-pooled and split into the mean and log-variance of the latent distribution. Note that the encoder as written uses only node features and connectivity, not the bond feature vectors.

2.  **Decoder (`AutoregressiveDecoder`):**
    This is the generative part of the model. It takes a vector from the latent space and builds a new molecule step-by-step. This **autoregressive** process is powerful because each new decision is based on the state of the partially built molecule. At each step, the decoder makes three key predictions using separate MLP "heads":
    * **Node Type:** What kind of atom should be added next? (e.g., Carbon, Nitrogen)
    * **Edge Existence:** Should the new atom be connected to the existing atoms?
    * **Edge Type:** If a connection is made, what kind of bond should it be? (e.g., Single, Double)

#### The Training Process 🏋️‍♂️

Training this model is more complex than a standard supervised task. We use a method called **Teacher Forcing**.

During training, instead of letting the model use its own (potentially wrong) predictions to build the next step, we always provide it with the correct ground-truth structure from the dataset. This ensures the model learns the correct decision at every single step of molecule construction.
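
One way to picture teacher forcing here: at step `t` the decoder sees the ground-truth subgraph on nodes `0..t-1` and is asked which of those nodes the incoming node `t` bonds to. The sketch below builds those binary targets from a plain edge list. Both helpers are hypothetical names introduced for illustration; the training loop in the notebook does the equivalent with tensor masks.

```python
from typing import List, Tuple


def edge_targets_at_step(edges: List[Tuple[int, int]], t: int) -> List[float]:
    """Binary targets: does ground-truth node t bond to each earlier node j < t?"""
    undirected = {frozenset(e) for e in edges}
    return [1.0 if frozenset((t, j)) in undirected else 0.0 for j in range(t)]


def partial_edges(edges: List[Tuple[int, int]], t: int) -> List[Tuple[int, int]]:
    """Ground-truth edges among nodes 0..t, fed to the decoder at the next step."""
    return [(u, v) for (u, v) in edges if u <= t and v <= t]


# A four-atom branched skeleton: 0-1, 1-2, 1-3.
true_edges = [(0, 1), (1, 2), (1, 3)]
for t in range(1, 4):
    print(t, edge_targets_at_step(true_edges, t), partial_edges(true_edges, t))
```

Whatever the decoder predicted at step `t`, the graph passed to step `t + 1` is always the true one, so early mistakes do not compound during training.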

The model's performance is measured by a **multi-task loss function**, which is a combination of:
* **Node Type Loss:** Penalty for predicting the wrong atom type.
* **Edge Existence Loss:** Penalty for incorrect bond predictions.
* **Edge Type Loss:** Penalty for predicting the wrong bond type.
* **KL Divergence Loss:** A standard VAE loss that helps organize the latent space, making it better for generating novel molecules.

The final script initializes these models, sets up an Adam optimizer, and runs the training loop for a specified number of epochs, printing the average loss and wall-clock time per epoch. The loop shown in this notebook does not write checkpoints itself.
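
The KL term and the sampling step that feeds the decoder can be written out for a single latent vector as follows. This is a pure-Python illustration with hypothetical helper names; the training code applies the same closed-form expression to batched tensors and scales the KL term by 0.001.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], log_var: List[float], rng: random.Random) -> List[float]:
    """Sample z = mu + eps * std with eps ~ N(0, 1), where std = exp(0.5 * log_var)."""
    return [m + rng.gauss(0.0, 1.0) * math.exp(0.5 * lv) for m, lv in zip(mu, log_var)]


def kl_divergence(mu: List[float], log_var: List[float]) -> float:
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) in closed form."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


rng = random.Random(0)
print(reparameterize([0.0, 1.0], [0.0, -2.0], rng))
print(kl_divergence([0.0, 0.0], [0.0, 0.0]))  # zero when the posterior equals the prior
print(kl_divergence([1.0, 0.0], [0.0, 0.0]))  # shifting one mean by 1 costs 0.5
```

Sampling through `mu + eps * std` keeps the randomness in `eps`, so gradients can flow back to the encoder outputs.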

In [None]:

# ===================================================================
# Part 1: Definitions and Mappings (for code completeness)
# ===================================================================
ATOM_VOCAB = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'B', 'H', 'Unknown']
ATOM_MAP = {symbol: i for i, symbol in enumerate(ATOM_VOCAB)}
NUM_ATOM_TYPES = len(ATOM_VOCAB)

BOND_MAP = {
    Chem.rdchem.BondType.SINGLE: 0,
    Chem.rdchem.BondType.DOUBLE: 1,
    Chem.rdchem.BondType.TRIPLE: 2,
    Chem.rdchem.BondType.AROMATIC: 3,
}
NUM_BOND_TYPES = len(BOND_MAP)

# ===================================================================
# Part 2: Defining the Model Architectures (Encoder and Decoder)
# ===================================================================

class GraphVAE_GAT(Module):
    def __init__(self, node_feature_dim, latent_dim=128, heads=4):
        super(GraphVAE_GAT, self).__init__()
        self.latent_dim = latent_dim
        self.encoder_conv1 = GATConv(node_feature_dim, 256, heads=heads)
        self.encoder_conv2 = GATConv(256 * heads, 256, heads=heads)
        self.encoder_conv3 = GATConv(256 * heads, latent_dim * 2, heads=1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.encoder_conv1(x, edge_index))
        x = F.elu(self.encoder_conv2(x, edge_index))
        x = self.encoder_conv3(x, edge_index)
        x = global_mean_pool(x, batch)
        mu = x[:, :self.latent_dim]
        log_var = x[:, self.latent_dim:]
        return mu, log_var, x



class AutoregressiveDecoder(Module):
    def __init__(self, node_feature_dim, latent_dim, hidden_dim=128):
        super(AutoregressiveDecoder, self).__init__()
        self.gnn = GATConv(node_feature_dim + latent_dim, hidden_dim, heads=4)

        mlp_input_dim = hidden_dim * 4

        self.mlp_node_type = Sequential(
            Linear(mlp_input_dim, hidden_dim), nn.ReLU(),
            Linear(hidden_dim, NUM_ATOM_TYPES)
        )
        self.mlp_predict_edge = Sequential(
            Linear(mlp_input_dim * 2, hidden_dim), nn.ReLU(),
            Linear(hidden_dim, 1)
        )
        # Head 3: bond type for each (new atom, existing atom) pair that is connected
        self.mlp_edge_type = Sequential(
            Linear(mlp_input_dim * 2, hidden_dim), nn.ReLU(),
            Linear(hidden_dim, NUM_BOND_TYPES) # Predicts one of 4 bond types
        )

In [None]:
def train_autoregressive_epoch(encoder, decoder, dataloader, optimizer, device):
    encoder.train()
    decoder.train()
    total_epoch_loss = 0

    for data_batch in tqdm(dataloader, desc="Training Epoch"):
        optimizer.zero_grad()
        # Move the batch to the target device and encode it into (mu, log_var)
        data_batch = data_batch.to(device)
        try:
            mu_batch, log_var_batch, _ = encoder(data_batch)
        except (IndexError, TypeError) as e:
            print(f"Skipping batch due to an error: {e}")
            continue

        std = torch.exp(0.5 * log_var_batch)
        eps = torch.randn_like(std)
        z_batch = mu_batch + eps * std

        batch_loss_list = []
        for i in range(data_batch.num_graphs):
            # Decode each molecule of the batch separately, conditioned on its own latent vector
            z = z_batch[i]
            true_graph = data_batch.get_example(i)
            num_nodes = true_graph.num_nodes
            if num_nodes <= 1: continue

            partial_x = true_graph.x[0].unsqueeze(0)
            partial_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)

            molecule_loss_list = []
            for t in range(1, num_nodes):
                # Embed the partial graph (nodes 0..t-1), concatenating z to every node's features
                z_expanded = z.repeat(t, 1)
                gnn_input_x = torch.cat([partial_x, z_expanded], dim=1)
                node_embeddings = F.elu(decoder.gnn(gnn_input_x, partial_edge_index))
                graph_embedding = global_add_pool(node_embeddings, torch.zeros(t, dtype=torch.long, device=device))

                # --- Loss Calculation ---
                # 1. Node type loss: predict the type of node t from the pooled graph embedding
                true_node_type_idx = ATOM_MAP.get(Chem.Atom(int(true_graph.x[t, 0])).GetSymbol(), ATOM_MAP['Unknown'])
                pred_node_type_logits = decoder.mlp_node_type(graph_embedding)
                loss_node_type = F.cross_entropy(pred_node_type_logits, torch.tensor([true_node_type_idx], device=device))
                molecule_loss_list.append(loss_node_type)

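                # 2. Edge prediction loss (t starts at 1, so this branch always runs)
                if t > 0:
                    # Node t is not in the partial graph yet, so the embedding of the most
                    # recently added node (t-1) stands in for it when scoring candidate edges.
                    new_node_emb = node_embeddings[-1].unsqueeze(0).repeat(t, 1)
                    existing_node_embs = node_embeddings
                    edge_mlp_input = torch.cat([new_node_emb, existing_node_embs], dim=1)

                    # Edge existence loss: one binary target per existing node j < t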
                    pred_edge_logits = decoder.mlp_predict_edge(edge_mlp_input).squeeze(-1)
                    true_edges_exist = torch.zeros(t, device=device)
                    for j in range(t):
                        if ((true_graph.edge_index[0] == t) & (true_graph.edge_index[1] == j)).any():
                            true_edges_exist[j] = 1.0
                    loss_edge_exist = F.binary_cross_entropy_with_logits(pred_edge_logits, true_edges_exist)
                    if not torch.isnan(loss_edge_exist):
                        molecule_loss_list.append(loss_edge_exist)

                    # Edge type loss, computed only for bonds that exist in the ground truth
                    true_edge_indices = (true_edges_exist == 1.0).nonzero().squeeze(-1)
                    if true_edge_indices.numel() > 0:
                        # Get embeddings for only the nodes that are truly connected
                        connected_mlp_input = edge_mlp_input[true_edge_indices]
                        pred_edge_type_logits = decoder.mlp_edge_type(connected_mlp_input)

                        # Get the ground truth bond types
                        true_bond_types = []
                        for j in true_edge_indices:
                            bond_mask = ((true_graph.edge_index[0] == t) & (true_graph.edge_index[1] == j))
                            bond_type_feature_vector = true_graph.edge_attr[bond_mask][0]
                            # The first four entries are the bond-type one-hot; entry 4 is the ring flag
                            bond_type_idx = bond_type_feature_vector[:4].nonzero().item()
                            true_bond_types.append(bond_type_idx)

                        loss_edge_type = F.cross_entropy(pred_edge_type_logits, torch.tensor(true_bond_types, device=device))
                        if not torch.isnan(loss_edge_type):
                            molecule_loss_list.append(loss_edge_type)

                # --- Teacher forcing: extend the partial graph with ground-truth node t and its edges ---
                partial_x = true_graph.x[:t+1]
                edge_mask = (true_graph.edge_index[0] <= t) & (true_graph.edge_index[1] <= t)
                partial_edge_index = true_graph.edge_index[:, edge_mask]

            # Average the per-step losses of this molecule, then average over the batch and add KL
            if molecule_loss_list:
                batch_loss_list.append(torch.stack(molecule_loss_list).mean())
        if batch_loss_list:
            final_batch_loss = torch.stack(batch_loss_list).mean()
            kl_loss = -0.5 * torch.sum(1 + log_var_batch - mu_batch.pow(2) - log_var_batch.exp())
            final_batch_loss += (kl_loss / data_batch.num_graphs) * 0.001
            final_batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), 1.0)
            optimizer.step()
            total_epoch_loss += final_batch_loss.item()

    return total_epoch_loss / len(dataloader)

In [None]:
# ===================================================================
# Section 4: Final Initialization and Training Loop
# ===================================================================
# --- Model Parameters ---
LATENT_DIM = 128
HIDDEN_DIM = 128
NODE_FEATURE_DIM = graph_list[0].num_node_features

# --- Final Preparation ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Build Models ---
encoder = GraphVAE_GAT(node_feature_dim=NODE_FEATURE_DIM, latent_dim=LATENT_DIM).to(device)
decoder = AutoregressiveDecoder(node_feature_dim=NODE_FEATURE_DIM, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM).to(device)

# --- Optimizer ---
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.0001)

# --- Data Loader ---
# ❗️ Tip: For initial testing, use a small subset
small_graph_list = graph_list[:4096]
data_loader = DataLoader(small_graph_list, batch_size=512, shuffle=True)
# data_loader = DataLoader(graph_list, batch_size=16, shuffle=True)


# --- Start Training ---
NUM_EPOCHS = 2 # ❗️ For starting, test with a small number (e.g., 1 or 2)
print(f"Starting training for {NUM_EPOCHS} epochs...")
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.time()
    loss = train_autoregressive_epoch(encoder, decoder, data_loader, optimizer, device)
    end_time = time.time()
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Time: {(end_time-start_time):.2f}s")

### 4. Molecule Generation & Analysis

This final section brings the entire project to its conclusion. Here, we use the fully trained autoregressive model to generate a large library of novel molecules. We then evaluate these generated molecules using standard chemoinformatics metrics to assess their quality and potential as drug candidates.

#### The Generation Process

The core of this step is the `generate_final_robust_molecule` function. It combines three strategies:
* **Autoregressive Construction:** It builds molecules step by step, starting from a single carbon atom and sampling the next atom type from the decoder's prediction for the partially constructed graph. Generation stops at `max_nodes` atoms or when the model samples the `Unknown` atom type.
* **Top-k Connection ("Adventurous" Sampling):** To discourage disconnected fragments ("atomic soup"), the function tries to bond each new atom to the `k` existing atoms with the highest predicted edge probability instead of thresholding those probabilities. Connectivity is not guaranteed, because candidate bonds can be rejected by the valence check, so any result whose SMILES still contains several fragments is labeled `"Disconnected"` and discarded.
* **Valence Constraints:** Before adding a bond, the function checks that neither atom would exceed its entry in a simplified maximum-valence table (`MAX_VALENCE`, e.g., 4 for carbon). This filters out many chemically impossible graphs before RDKit sanitization.

This process is repeated 1000 times to create a large and diverse pool of candidate molecules.
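
The interaction between top-k selection and the valence check can be sketched without any model. The code below is a simplified illustration under stated assumptions: the helper names are hypothetical, only a subset of the valence table is copied, and every bond is treated as a single bond, whereas the notebook also samples a bond type.

```python
from typing import Dict, List, Tuple

# Simplified per-element limits (subset of the notebook's MAX_VALENCE table).
MAX_VALENCE: Dict[str, int] = {"C": 4, "N": 3, "O": 2, "F": 1, "H": 1}


def top_k_candidates(edge_probs: List[float], k: int) -> List[int]:
    """Indices of the k existing atoms with the highest predicted edge probability."""
    order = sorted(range(len(edge_probs)), key=lambda j: edge_probs[j], reverse=True)
    return order[:k]


def try_add_bond(symbols: List[str], valences: List[float],
                 u: int, v: int, bond_order: float) -> bool:
    """Add the bond only if neither atom would exceed its maximum valence."""
    fits_u = valences[u] + bond_order <= MAX_VALENCE.get(symbols[u], 100)
    fits_v = valences[v] + bond_order <= MAX_VALENCE.get(symbols[v], 100)
    if fits_u and fits_v:
        valences[u] += bond_order
        valences[v] += bond_order
        return True
    return False


def connect_new_atom(symbols: List[str], valences: List[float], new_idx: int,
                     edge_probs: List[float], k: int) -> List[Tuple[int, int]]:
    """Try single bonds from the new atom to its top-k candidates; return accepted bonds."""
    accepted = []
    for j in top_k_candidates(edge_probs, k):
        if try_add_bond(symbols, valences, new_idx, j, 1.0):
            accepted.append((new_idx, j))
    return accepted


# Existing fragment F-O-C stored as atoms O, C, F, plus a new carbon at index 3.
symbols = ["O", "C", "F", "C"]
probs = [0.9, 0.6, 0.7]  # illustrative edge probabilities, not model output
print(connect_new_atom(symbols, [2.0, 1.0, 1.0, 0.0], 3, probs, k=2))
print(connect_new_atom(symbols, [2.0, 1.0, 1.0, 0.0], 3, probs, k=3))
```

With `k=2` both top candidates (O and F) are already saturated, so the new atom stays unconnected; with `k=3` the carbon at index 1 accepts the bond. This is why the generator still needs its final check for disconnected fragments.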

#### Molecule Scoring

After generation, the unique, valid molecules are scored on two widely used cheminformatics metrics:

* **QED (Quantitative Estimation of Drug-likeness):** This metric scores a molecule on a scale from 0 to 1, based on how well its physicochemical properties align with those of known oral drugs. **A higher QED score is better.**
* **SA Score (Synthetic Accessibility Score):** This score estimates how difficult it would be to synthesize the molecule in a lab. It ranges from 1 (very easy) to 10 (extremely difficult). **A lower SA Score is better.**

The final output of this cell is a pandas DataFrame, `df_scores`, containing the SMILES strings of the successfully generated molecules and their corresponding QED and SA scores, ready for the final statistical analysis and visualization.

In [None]:
# ===================================================================
# Section 1: Required Dictionaries & Mappings
# ===================================================================

# --- Atom Vocab ---
ATOM_VOCAB = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P', 'B', 'H', 'Unknown']
ATOM_MAP = {symbol: i for i, symbol in enumerate(ATOM_VOCAB)}
INV_ATOM_MAP = {i: symbol for symbol, i in ATOM_MAP.items()}

# --- Bond Vocab & Mappings ---
INV_BOND_TYPE_MAP = {
    0: Chem.rdchem.BondType.SINGLE,
    1: Chem.rdchem.BondType.DOUBLE,
    2: Chem.rdchem.BondType.TRIPLE,
    3: Chem.rdchem.BondType.AROMATIC
}

RDKIT_BOND_VALENCE_MAP = {
    Chem.rdchem.BondType.SINGLE: 1.0,
    Chem.rdchem.BondType.DOUBLE: 2.0,
    Chem.rdchem.BondType.TRIPLE: 3.0,
    Chem.rdchem.BondType.AROMATIC: 1.5
}

# --- Valence Constraints ---
MAX_VALENCE = {
    'C': 4, 'N': 3, 'O': 2, 'F': 1, 'Cl': 1, 'Br': 1, 'I': 1, 'S': 2,
    'P': 3, 'B': 3, 'H': 1
}

# --- Model Parameters (ensure this matches your model's training) ---
# Example: NODE_FEATURE_DIM = 6 (if you get it from graph_list[0].num_node_features)


# ===================================================================
# Section 2: Final Robust Generation Function
# ===================================================================

def generate_final_robust_molecule(decoder, latent_vector, k=2, max_nodes=40, device='cpu'):
    """
    Generates a single molecule using robust valence constraints and bond type predictions.
    """
    with torch.no_grad():
        # Data structures to build the molecule
        nodes = [6] # Start with a Carbon atom (atomic number)
        edges = [] # List of tuples (u, v, bond_type)

        # Keep track of valences throughout generation
        valences = torch.zeros(max_nodes, dtype=torch.float, device=device)

        for t in range(1, max_nodes):
            # Build the current partial graph for the model
            partial_x = torch.zeros(t, NODE_FEATURE_DIM, device=device)
            for i, atomic_num in enumerate(nodes):
                partial_x[i, 0] = atomic_num

            if not edges:
                partial_edge_index = torch.empty((2, 0), dtype=torch.long, device=device)
            else:
                edge_tens = torch.tensor([(u,v) for u,v,bt in edges], device=device).t()
                partial_edge_index = torch.cat([edge_tens, edge_tens.flip(0)], dim=1)

            # Get Predictions from the Decoder
            z_expanded = latent_vector.repeat(t, 1)
            gnn_input_x = torch.cat([partial_x, z_expanded], dim=1)
            node_embeddings = F.elu(decoder.gnn(gnn_input_x, partial_edge_index))
            graph_embedding = global_add_pool(node_embeddings, torch.zeros(t, dtype=torch.long, device=device))

            pred_node_type_logits = decoder.mlp_node_type(graph_embedding)
            new_node_type_idx = torch.multinomial(F.softmax(pred_node_type_logits, dim=-1), 1).squeeze().item()

            if new_node_type_idx == ATOM_MAP['Unknown']:
                break

            new_atom_symbol = INV_ATOM_MAP[new_node_type_idx]
            nodes.append(Chem.Atom(new_atom_symbol).GetAtomicNum())
            max_val_new = MAX_VALENCE.get(new_atom_symbol, 100)

            if t > 0 and t < max_nodes:
                new_node_emb = node_embeddings[-1].unsqueeze(0).repeat(t, 1)
                edge_mlp_input = torch.cat([new_node_emb, node_embeddings], dim=1)

                pred_edge_logits = decoder.mlp_predict_edge(edge_mlp_input).squeeze(-1)
                edge_probs = torch.sigmoid(pred_edge_logits)

                num_to_connect = min(k, t)
                top_k_indices = torch.topk(edge_probs, k=num_to_connect).indices

                top_k_mlp_input = edge_mlp_input[top_k_indices]
                pred_edge_type_logits = decoder.mlp_edge_type(top_k_mlp_input)
                bond_type_indices = torch.multinomial(F.softmax(pred_edge_type_logits, dim=-1), 1).squeeze(-1)

                for i, j in enumerate(top_k_indices):
                    j = j.item()
                    bond_type_idx = bond_type_indices[i].item() if bond_type_indices.dim() > 0 else bond_type_indices.item()
                    rdkit_bond_type = INV_BOND_TYPE_MAP.get(bond_type_idx, Chem.rdchem.BondType.SINGLE)
                    bond_valence = RDKIT_BOND_VALENCE_MAP.get(rdkit_bond_type, 1.0)

                    atom_j_symbol = Chem.Atom(nodes[j]).GetSymbol()
                    max_val_j = MAX_VALENCE.get(atom_j_symbol, 100)

                    if valences[j] + bond_valence <= max_val_j and valences[t] + bond_valence <= max_val_new:
                        edges.append((t, j, rdkit_bond_type))
                        valences[t] += bond_valence
                        valences[j] += bond_valence

    # --- Final step: Convert the collected nodes and edges to SMILES ---
    mol = Chem.RWMol()
    for atomic_num in nodes:
        mol.AddAtom(Chem.Atom(atomic_num))

    for u, v, bond_type in edges:
        mol.AddBond(u, v, bond_type)

    try:
        Chem.SanitizeMol(mol)
        smi = Chem.MolToSmiles(mol)
        # Final check for disconnected fragments
        if '.' in smi:
            return "Disconnected"
        return smi
    except Exception:
        return "Invalid"

In [None]:
# Ensure your trained models are loaded and in eval() mode
# Ensure all helper functions and maps are defined
decoder.eval()
encoder.eval()
print("🧪 Generating 1000 molecules...")
generated_smiles = []
for i in tqdm(range(1000), desc="Generating Molecules"):
    random_z = torch.randn(1, LATENT_DIM, device=device)
    # Using k=2 or k=3 can sometimes create more complex molecules
    smiles = generate_final_robust_molecule(decoder, random_z, k=2, device=device)

    # We only keep valid, single-fragment molecules
    if smiles not in ["Invalid", "Disconnected"]:
        generated_smiles.append(smiles)

# Remove duplicates
unique_smiles = list(set(generated_smiles))
print(f"\n✅ Generated {len(unique_smiles)} unique, valid molecules.")


scores = []
for smi in tqdm(unique_smiles, desc="Scoring Molecules"):
    mol = Chem.MolFromSmiles(smi)
    if mol:
        qed_score = QED.qed(mol)
        sa_score = sascorer.calculateScore(mol)
        scores.append({'smiles': smi, 'qed': qed_score, 'sa_score': sa_score})

# Create a DataFrame for easy analysis
df_scores = pd.DataFrame(scores)

### 5. Final Analysis & Visualization

This is the final and most rewarding step of the project. After generating a large pool of molecules, we now analyze them quantitatively to identify the most promising candidates and visualize the best result.

#### Selecting the Best Candidates 🏆

To rank the generated molecules, we define a custom **`combined_score`** that balances drug-likeness against ease of synthesis. The formula used is:

`combined_score = 2 * QED - (SA_Score / 10)`

QED ranges from 0 to 1 (higher is more drug-like), while the SA Score ranges from 1 (easy to synthesize) to 10 (hard to synthesize). Doubling QED and dividing the SA Score by 10 therefore makes **QED** the dominant term, with the **SA Score** acting as a smaller penalty. The script first identifies the single best molecule based on this score and then selects the top 100 candidates for a more detailed statistical analysis.
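
To make the weighting concrete, here is a small worked example of the formula. The three candidates and their property values are hypothetical, chosen only to show how the two terms trade off.

```python
def combined_score(qed: float, sa_score: float) -> float:
    """Ranking score from the formula above: 2 * QED - SA_Score / 10."""
    return 2 * qed - sa_score / 10


# Hypothetical (QED, SA Score) pairs, for illustration only
candidates = {
    "drug-like, easy to make": (0.80, 2.0),
    "drug-like, hard to make": (0.80, 6.0),
    "less drug-like, easy to make": (0.50, 2.0),
}

# Rank from best to worst combined score
ranked = sorted(
    candidates,
    key=lambda name: combined_score(*candidates[name]),
    reverse=True,
)

for name in ranked:
    qed, sa = candidates[name]
    print(f"{name}: {combined_score(qed, sa):.2f}")
```

In this toy set, raising the SA Score by 4 points costs 0.4, while lowering QED by 0.3 costs 0.6, so the hard-to-make but drug-like candidate still ranks above the easy-to-make but less drug-like one.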

#### Statistical and Visual Analysis 📊

To understand the properties of the high-quality molecules our model generated, the code produces a comprehensive analysis of the top 100 candidates:

* **Descriptive Statistics:** A table is printed, summarizing key statistics like the mean, median, standard deviation, and quartiles for the QED, SA Score, and combined scores. Because these are the top-ranked molecules, the table describes the best candidates rather than the model's overall output.
* **Visual Plots:** Three plots are generated to visualize the distribution and relationship of these properties:
    1.  **QED Distribution:** Shows the range and frequency of drug-likeness scores among the top molecules.
    2.  **SA Score Distribution:** Shows the distribution of synthetic accessibility scores.
    3.  **QED vs. SA Scatter Plot:** Visualizes the trade-off between the two key metrics for the entire set of top candidates.
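
As a rough illustration of what the descriptive-statistics table contains, the sketch below computes the same kinds of quantities with Python's standard `statistics` module. It is a stand-in for the pandas call used in the notebook, not a replacement for it. The helper name `describe` and the QED values are hypothetical.

```python
import statistics


def describe(values):
    """Count, mean, sample standard deviation, min, quartiles, and max."""
    # method="inclusive" treats the data as the full population range,
    # so the quartiles always lie between min(values) and max(values)
    q1, median, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "count": len(values),
        "mean": statistics.mean(values),
        "std": statistics.stdev(values),  # sample standard deviation
        "min": min(values),
        "25%": q1,
        "50%": median,
        "75%": q3,
        "max": max(values),
    }


# Made-up QED values, for illustration only
toy_qed = [0.62, 0.70, 0.74, 0.81, 0.88]
summary = describe(toy_qed)

for key, value in summary.items():
    print(f"{key:>5}: {value:.3f}")
```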

#### 3D Visualization of the Top Molecule 🔬

As a final step, the single best molecule is rendered in 3D. The script uses **RDKit** to:
1.  Add hydrogen atoms to the 2D structure.
2.  Embed the molecule into 3D space to generate realistic coordinates.
3.  Optimize the 3D geometry using the Universal Force Field (UFF).

The resulting 3D structure is then displayed using **py3Dmol**, providing an interactive view of the novel molecule designed by our model.

In [None]:
# We want to maximize this score (high QED, low SA Score)
# Dividing sa_score (range 1-10) by 10 brings it onto a scale comparable to QED (range 0-1)
df_scores['combined_score'] = 2 * df_scores['qed'] - (df_scores['sa_score'] / 10)

# Sort once by combined score (best first) and reuse the result
df_ranked = df_scores.sort_values(by='combined_score', ascending=False)

# Find the molecule with the best combined score
best_molecule_row = df_ranked.iloc[0]
best_smiles = best_molecule_row['smiles']
print("\n🏆 Best molecule found:")
print(best_molecule_row)


df_top100 = df_ranked.head(100)

print("\n\n\n🏆 Top 100 molecules selected. Here's a preview:")
print(df_top100.head())


print("\n\n\n📊 Descriptive Statistics for the Top 100 Generated Molecules:")

# The .describe() method calculates count, mean, std, min,
# 25th percentile (Q1), 50th percentile (median), 75th percentile (Q3), and max.
descriptive_stats = df_top100[['qed', 'sa_score', 'combined_score']].describe()

# Display the results
print(descriptive_stats)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: QED Distribution
axes[0].hist(df_top100['qed'], bins=30, color='tab:blue')
axes[0].axvline(df_top100['qed'].mean(), color='red', linestyle='--', label=f"Mean: {df_top100['qed'].mean():.2f}")
axes[0].set_title('QED Distribution (Top 100)')
axes[0].set_xlabel('QED (higher is better)')
axes[0].set_ylabel('Count')
axes[0].legend()  # Without this call, the mean label is never displayed

# Plot 2: SA Score Distribution
axes[1].hist(df_top100['sa_score'], bins=20, color='tab:blue')
axes[1].axvline(df_top100['sa_score'].mean(), color='red', linestyle='--', label=f"Mean: {df_top100['sa_score'].mean():.2f}")
axes[1].set_title('SA Score Distribution (Top 100)')
axes[1].set_xlabel('SA (lower = easier)')
axes[1].set_ylabel('Count')
axes[1].legend()  # Without this call, the mean label is never displayed

# Plot 3: QED vs SA Score
axes[2].scatter(df_top100['qed'], df_top100['sa_score'], alpha=0.5)
axes[2].axvline(df_top100['qed'].mean(), color='red', linestyle='--')
axes[2].axhline(df_top100['sa_score'].mean(), color='red', linestyle='--')
axes[2].set_title('QED vs SA Score (Top 100)')
axes[2].set_xlabel('QED (higher is better)')
axes[2].set_ylabel('SA (lower is better)')
plt.tight_layout()
plt.show()

In [None]:
# Create molecule object and add hydrogens
mol = Chem.MolFromSmiles(best_smiles)
mol_with_hs = Chem.AddHs(mol)

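# Generate 3D coordinates; EmbedMolecule returns -1 if embedding fails
if AllChem.EmbedMolecule(mol_with_hs, randomSeed=42) == -1:
    raise RuntimeError(f"3D embedding failed for {best_smiles}")

# Relax the geometry with the Universal Force Field (UFF)
AllChem.UFFOptimizeMolecule(mol_with_hs)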

# --- Visualization ---
p = py3Dmol.view(width=500, height=500)

# Convert RDKit mol to MOL block format
mblock = Chem.MolToMolBlock(mol_with_hs)
p.addModel(mblock, 'mol')

# Style the view
p.setStyle({'stick':{}})
p.setBackgroundColor('0xeeeeee')
p.zoomTo()
print("\n🔬 3D view of the best generated molecule:")
p.show()

### 6. Conclusion & Final Summary

This project built an end-to-end pipeline for **graph-based *de novo* molecule design**. Representing molecules as graphs lets the model work directly with atoms and bonds, capturing structural information that string-based methods encode only indirectly.

The core of the project is an **Autoregressive Variational Autoencoder (VAE)**. The **GAT-based Encoder** effectively learned to map complex molecules to a continuous latent space, while the **Autoregressive Decoder** demonstrated its ability to generate new graphs step-by-step.

The development process was iterative and highlighted key challenges in generative chemistry. We progressed from simple models that produced invalid structures to more sophisticated ones that generated syntactically valid but chemically unrealistic molecules ("graph monsters"). The final breakthrough was achieved by combining our trained neural network with a **"smart" constrained generation algorithm**. This hybrid approach, which enforces chemical rules like valence during generation, proved essential for producing valid and connected structures.
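
The valence rule at the heart of this constrained generation can be sketched in a few lines. This is a simplified illustration, not the project's generation function: the `MAX_VALENCE` table, the helper names, and the use of integer bond orders are assumptions made for the example.

```python
# Typical maximum valences for a few neutral atoms, keyed by atomic number.
# Illustrative assumption: the project's own valence map may differ.
MAX_VALENCE = {6: 4, 7: 3, 8: 2, 9: 1, 17: 1}  # C, N, O, F, Cl


def can_add_bond(nodes, valences, i, j, bond_order):
    """Return True if a bond of this order keeps both atoms within their limits."""
    return (
        valences[i] + bond_order <= MAX_VALENCE[nodes[i]]
        and valences[j] + bond_order <= MAX_VALENCE[nodes[j]]
    )


def add_bond(nodes, valences, edges, i, j, bond_order):
    """Add the bond only if it passes the valence check; report whether it was added."""
    if not can_add_bond(nodes, valences, i, j, bond_order):
        return False
    edges.append((i, j, bond_order))
    valences[i] += bond_order
    valences[j] += bond_order
    return True


nodes = [6, 8, 6]      # atomic numbers: C, O, C
valences = [0, 0, 0]   # running valence used by each atom
edges = []

first = add_bond(nodes, valences, edges, 0, 1, 2)   # C=O is accepted
second = add_bond(nodes, valences, edges, 1, 2, 1)  # rejected: O would reach valence 3

print(first, second, valences, edges)
```

During generation, a check of this kind can be applied to each bond the decoder proposes, so that chemically impossible bonds are filtered out before they are added to the graph.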

#### Key Achievements
* An **autoregressive generative model** was implemented, debugged, and trained.
* The model demonstrated the ability to generate **novel, valid, and connected** molecules.
* A final analysis using **QED** and **SA Score** ranked the generated molecules and identified drug-like, synthetically accessible candidates among them.

#### Future Directions
This work serves as a strong foundation for more advanced research. The clear next steps are:
1.  **Extended Pre-training:** Continuing the supervised training on a larger dataset and for more epochs to further improve the model's understanding of complex chemical patterns.
2.  **Reinforcement Learning (RL):** Implementing an RL fine-tuning loop to optimize the generation process toward specific, user-defined objectives, such as maximizing the docking score against a protein target or achieving a specific property profile.

Overall, this project provides a working end-to-end framework that can serve as a starting point for AI-driven molecule design.