# 📓 Molecule Generator with VAE-based Graph Neural Network (VAE-GNN)
This notebook implements a molecule generator that combines a Variational Autoencoder (VAE) with a Graph Neural Network (GNN) encoder. Scaffold conditioning is the intended next step for improving molecular generation; it is not yet wired into the model below.

## 🔧 Imports 

In [None]:
import random

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import Draw
from torch import nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.utils import softmax as segment_softmax
from tqdm import tqdm

## 🧪 Data Preparation: Molecule to Graph Conversion

In [None]:
def mol_to_graph(mol):
    """Convert an RDKit molecule into a PyTorch Geometric ``Data`` graph.

    Each atom becomes a node whose single feature is its atomic number, and
    each bond becomes two directed edges (i -> j and j -> i).

    Args:
        mol: An RDKit ``Mol``, or a SMILES string (parsed with
            ``Chem.MolFromSmiles``).

    Returns:
        ``Data`` with ``x`` of shape [num_atoms, 1] (float) and
        ``edge_index`` of shape [2, 2 * num_bonds] (long).

    Raises:
        ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    if isinstance(mol, str):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles!r}")

    nodes = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]

    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.extend(([i, j], [j, i]))

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Bond-free molecules (single atoms, ions) still need a [2, 0] tensor;
        # torch.tensor([]) would have shape [0] and break message passing.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # view(-1, 1) keeps x two-dimensional even for a molecule with no atoms.
    x = torch.tensor(nodes, dtype=torch.float).view(-1, 1)
    return Data(x=x, edge_index=edge_index)

## 🔬 Multi-Head Graph Attention Layer (from scratch)

In [None]:
class GraphVAE(MessagePassing):
    """Multi-head dot-product attention message-passing layer.

    Despite its name, this class is a single graph-attention layer (used as
    the building block of ``VAE_GNN``'s encoder), not a full VAE.

    Every node attends over its incoming neighbours. Attention scores are
    normalised with a softmax over the edges that share a target node and are
    then used to weight the neighbours' value vectors. The per-head results
    are concatenated, projected back to ``out_channels``, added to a residual
    (projected when widths differ) and layer-normalised.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector; also the
            per-head query/key/value width.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, in_channels, out_channels, heads=4, dropout=0.1):
        # node_dim=0: Q/K/V are laid out as [num_nodes, heads, out_channels],
        # so nodes live on dim 0 rather than on MessagePassing's default -2
        # (which would index the heads axis).
        super().__init__(aggr='add', node_dim=0)
        self.heads = heads
        self.out_channels = out_channels

        self.lin_q = nn.Linear(in_channels, heads * out_channels)
        self.lin_k = nn.Linear(in_channels, heads * out_channels)
        self.lin_v = nn.Linear(in_channels, heads * out_channels)
        self.lin_out = nn.Linear(heads * out_channels, out_channels)

        # Residual path: project x when widths differ so `out + x` is defined.
        if in_channels == out_channels:
            self.res_proj = nn.Identity()
        else:
            self.res_proj = nn.Linear(in_channels, out_channels)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index):
        """Apply one attention round.

        Args:
            x: Node features, shape [num_nodes, in_channels].
            edge_index: Edge list, shape [2, num_edges].

        Returns:
            Updated node features, shape [num_nodes, out_channels].
        """
        Q = self.lin_q(x).view(-1, self.heads, self.out_channels)
        K = self.lin_k(x).view(-1, self.heads, self.out_channels)
        V = self.lin_v(x).view(-1, self.heads, self.out_channels)

        out = self.propagate(edge_index, Q=Q, K=K, V=V)
        out = self.lin_out(out.view(-1, self.heads * self.out_channels))
        return self.norm(out + self.res_proj(x))

    def message(self, Q_i, K_j, V_j, index, ptr, size_i):
        # Scaled dot-product score per edge and head: [num_edges, heads, 1].
        score = (Q_i * K_j).sum(dim=-1, keepdim=True) / (self.out_channels ** 0.5)
        # Normalise over each target node's incoming edges (grouped by
        # `index`), not across heads, so 'add' aggregation yields a weighted
        # average of neighbour values.
        score = segment_softmax(score, index, ptr, size_i)
        score = self.dropout(score)
        return score * V_j

## 🏗️ VAE Molecule Generator with GNN Encoder (scaffold conditioning planned, not yet implemented)

In [None]:
class VAE_GNN(nn.Module):
    """Node-level VAE with a stack of ``GraphVAE`` attention layers as encoder.

    The encoder maps each node to the mean and log-variance of a diagonal
    Gaussian posterior. A latent sample is drawn with the reparameterisation
    trick and decoded by a two-layer MLP into one value per node.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Width of every attention layer and of the decoder's
            hidden layer.
        latent_dim: Size of each node's latent vector.
        num_layers: Number of stacked ``GraphVAE`` layers in the encoder.
    """

    def __init__(self, in_channels, hidden_channels, latent_dim=16, num_layers=4):
        super().__init__()
        # First layer maps in_channels -> hidden_channels; later ones are
        # hidden -> hidden.
        self.layers = nn.ModuleList(
            GraphVAE(in_channels if i == 0 else hidden_channels, hidden_channels)
            for i in range(num_layers)
        )

        # VAE heads: posterior statistics, then decoder back to node values.
        self.fc_mu = nn.Linear(hidden_channels, latent_dim)
        self.fc_logvar = nn.Linear(hidden_channels, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, hidden_channels)
        self.fc_out = nn.Linear(hidden_channels, 1)

    def encode(self, x, edge_index):
        """Run the attention stack; return per-node ``(mu, logvar)``."""
        for layer in self.layers:
            x = layer(x, edge_index)
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to one output logit per row of ``z``."""
        h = F.relu(self.fc_decode(z))
        return self.fc_out(h)

    def forward(self, x, edge_index):
        """Encode, sample and decode.

        Returns:
            Tuple ``(out, mu, logvar)``: the decoder output of shape
            [num_nodes, 1] plus the posterior statistics the VAE loss needs.
        """
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z)
        return out, mu, logvar


## 🚀 Training and Sampling Pipeline

In [None]:
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Return the beta-weighted VAE loss: reconstruction + beta * KL.

    Args:
        recon_x: Decoder logits, same shape as ``x``.
        x: Reconstruction targets.
        mu: Posterior means.
        logvar: Posterior log-variances (same shape as ``mu``).
        beta: Weight of the KL term. Higher values regularise more (smoother
            latent space); lower values give the latent space more freedom.
            ``beta=1.0`` is the standard VAE objective (Kingma & Welling, 2013).

    Returns:
        Scalar tensor: summed BCE-with-logits plus ``beta`` times the KL term.
    """
    # NOTE(review): BCE expects targets in [0, 1]. Raw atomic numbers (as
    # produced by mol_to_graph) fall outside that range, which makes this term
    # unbounded below; normalise the targets or switch losses before training.
    bce = F.binary_cross_entropy_with_logits(recon_x, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kl

def train(model, loader, optimizer, criterion=None, epochs=10):
    """Train a VAE-style model and report the mean batch loss per epoch.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``. If it
            returns a tuple, that tuple is unpacked as
            ``(reconstruction, mu, logvar)``. Otherwise the output is taken as
            the reconstruction, and ``mu``/``logvar`` are read from the batch.
        loader: Iterable of batches exposing ``x`` and ``edge_index``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Callable ``(recon, x, mu, logvar) -> scalar tensor``.
            Defaults to ``loss_function``.
        epochs: Number of passes over ``loader``.

    Returns:
        List of mean per-batch losses, one entry per epoch (0.0 for an epoch
        with no batches).
    """
    if criterion is None:
        criterion = loss_function

    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for data in loader:
            optimizer.zero_grad()
            output = model(data.x, data.edge_index)
            if isinstance(output, tuple):
                recon, mu, logvar = output
            else:
                # Legacy path: posterior statistics attached to the batch.
                recon, mu, logvar = output, data.mu, data.logvar
            loss = criterion(recon, data.x, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches instead of len(loader): avoids ZeroDivisionError on
        # an empty loader and works for iterables without __len__.
        mean_loss = total_loss / max(num_batches, 1)
        history.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
    return history

## 🔬 Inference & Visualization

In [None]:
def generate_molecule(model, num_samples=5):
    """Decode ``num_samples`` random latent vectors with ``model``.

    For each sample, a fresh ``z ~ N(0, I)`` of shape
    ``[1, model.fc_mu.out_features]`` is drawn, moved to the device of
    ``model.fc_mu.weight`` and passed to ``model.decode``. Gradients are
    disabled, and the model is switched to eval mode (and left there).

    Returns:
        A list with one decoder output per sample.
    """

    def draw_latent():
        width = model.fc_mu.out_features
        return torch.randn(1, width).to(model.fc_mu.weight.device)

    model.eval()
    with torch.no_grad():
        return [model.decode(draw_latent()) for _ in range(num_samples)]

def visualize_molecule(molecule_data):
    """
    Build an RDKit molecule from a graph description, draw it, and display it.

    Parameters:
    molecule_data (dict): must provide
        - 'atoms': iterable of dicts, each with an 'atomic_num' key;
        - 'bonds': iterable of dicts, each with 'start_idx', 'end_idx' and
          'bond_type' keys. 'bond_type' may be a ``Chem.BondType``, a bond
          name such as 'DOUBLE', or a numeric order (1, 1.5, 2 or 3).

    Returns:
    The image produced by ``Draw.MolToImage``.

    Raises:
    ValueError: for a bond type that cannot be mapped to ``Chem.BondType``.

    Note: ``display`` is the IPython/Jupyter builtin, so this function is
    meant to run inside a notebook.
    """
    numeric_bond_types = {
        1.0: Chem.BondType.SINGLE,
        1.5: Chem.BondType.AROMATIC,
        2.0: Chem.BondType.DOUBLE,
        3.0: Chem.BondType.TRIPLE,
    }

    def to_bond_type(value):
        # Check BondType first: RDKit's BondType is an int subclass whose
        # integer values (e.g. AROMATIC == 12) are not bond orders.
        if isinstance(value, Chem.BondType):
            return value
        if isinstance(value, str):
            name = value.upper()
            if name in Chem.BondType.names:
                return Chem.BondType.names[name]
            raise ValueError(f"Unsupported bond type: {value!r}")
        try:
            return numeric_bond_types[float(value)]
        except (KeyError, TypeError, ValueError):
            raise ValueError(f"Unsupported bond type: {value!r}") from None

    mol = Chem.RWMol()
    for atom_data in molecule_data['atoms']:
        mol.AddAtom(Chem.Atom(atom_data['atomic_num']))

    for bond_data in molecule_data['bonds']:
        # Chem.Bond cannot be instantiated from Python; hand the bond type
        # directly to RWMol.AddBond instead.
        mol.AddBond(
            bond_data['start_idx'],
            bond_data['end_idx'],
            to_bond_type(bond_data['bond_type']),
        )

    mol = mol.GetMol()
    # Best-effort sanitisation so drawing has valences/ring info; with
    # catchErrors=True an invalid generated structure is still drawn instead
    # of raising.
    Chem.SanitizeMol(mol, catchErrors=True)

    img = Draw.MolToImage(mol)
    display(img)
    return img



In [None]:
# Example SMILES: expects a 'SMILES' column in smiles.csv.
smiles_list = pd.read_csv('smiles.csv')['SMILES'].tolist()

# Build one graph per molecule; skip strings RDKit cannot parse.
graph_data = []
for smiles in smiles_list:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Skipping unparsable SMILES: {smiles}")
        continue
    graph_data.append(mol_to_graph(mol))

data_loader = DataLoader(graph_data, batch_size=2, shuffle=True)

latent_dim = 32
# VAE_GNN is the full model (GraphVAE is only its attention layer).
# mol_to_graph emits a single node feature (atomic number), so in_channels=1.
model = VAE_GNN(in_channels=1, hidden_channels=128, latent_dim=latent_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): loss_function uses BCE-with-logits, whose targets should lie
# in [0, 1]; raw atomic numbers do not. Normalise features before relying on
# these loss values.
train(model, data_loader, optimizer, loss_function, epochs=10)

# TODO: scaffold conditioning is not implemented; generate_molecule only
# samples from the prior. The scaffold is kept here for when it is supported.
scaffold_smiles = "[*:0]NC1CN2CCC1CC2"
generate_molecule(model, num_samples=5)