In [1]:
# =============================================================================
# Project: Generative Chemistry and GNN-based Property Prediction (Final Version)
#
# Objective:
# 1. Generative Model (RNN): Train a model to generate new, valid molecules.
# 2. Predictive Model (GNN): Train a model to predict dipole moment.
# 3. Validation: Use a held-out Test Set to robustly validate the GNN's
#    predictive accuracy against known ground truth values.
# =============================================================================

# --- Step 1: Setup and Installations ---
import subprocess
import sys

def install_packages():
    """Installs required packages using a robust method for PyTorch Geometric."""
    print("--- Checking and installing dependencies ---")
    standard_packages = [
        "rdkit", "pandas", "scikit-learn", "tqdm", "torch",
        "torchvision", "torchaudio", "kagglehub", "selfies"
    ]
    for package in standard_packages:
        try:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
        except subprocess.CalledProcessError:
            print(f"ERROR: Failed to install {package}. Please try installing it manually.")
            sys.exit(1)
    try:
        import torch
        TORCH_VERSION = torch.__version__.split('+')[0]
        CUDA_VERSION = torch.version.cuda
        CUDA_STR = f"cu{CUDA_VERSION.replace('.', '')}" if CUDA_VERSION else 'cpu'
        print(f"Detected PyTorch {TORCH_VERSION} and device type {CUDA_STR}.")
        PYG_URL = f'https://data.pyg.org/whl/torch-{TORCH_VERSION}+{CUDA_STR}.html'
        pyg_packages = ['torch-scatter', 'torch-sparse', 'torch-cluster', 'torch-geometric']
        for package in pyg_packages:
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package, '-f', PYG_URL])
        print("--- All dependencies installed successfully. ---")
    except Exception as e:
        print(f"ERROR: Failed to install PyG packages: {e}")
        sys.exit(1)

install_packages()

# --- Step 2: Imports and Global Configuration ---
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from rdkit import Chem
import selfies as sf
import kagglehub

# --- Configuration ---
N_MOLECULES_GNN = 20000
N_MOLECULES_RNN = 50000
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
N_EPOCHS_GNN = 50
N_EPOCHS_RNN = 25
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Step 3: Define Helper Functions ---

def smiles_to_graph(smiles: str):
    """Converts a SMILES string into a PyTorch Geometric Data object."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None
    atom_features = [[
        atom.GetAtomicNum(), atom.GetFormalCharge(), float(atom.GetHybridization()),
        float(atom.GetIsAromatic()), atom.GetTotalNumHs(), atom.GetTotalValence()
    ] for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features, dtype=torch.float)
    if mol.GetNumBonds() > 0:
        row, col = [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row.extend([start, end])
            col.extend([end, start])
        edge_index = torch.tensor([row, col], dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

def evaluate_gnn(loader, model, scaler):
    """Evaluates the GNN and returns MAE and accuracy."""
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            out = model(batch)
            predictions.append(out.cpu().numpy())
            targets.append(batch.y.cpu().numpy())
    predictions = np.concatenate(predictions)
    targets = np.concatenate(targets)
    predictions_real = scaler.inverse_transform(predictions.reshape(-1, 1)).flatten()
    targets_real = scaler.inverse_transform(targets.reshape(-1, 1)).flatten()
    mae = np.mean(np.abs(predictions_real - targets_real))
    # Calculate accuracy, avoiding division by zero
    accuracies = 100 * (1 - np.abs(predictions_real - targets_real) / np.where(np.abs(targets_real) > 1e-6, np.abs(targets_real), 1.0))
    return mae, np.mean(accuracies)

def check_pubchem(smiles):
    """Checks if a molecule exists in PubChem. Returns CID if found, else None."""
    url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/smiles/{smiles}/cids/JSON"
    try:
        response = requests.get(url, timeout=5)
        if response.status_code == 200:
            data = response.json()
            return data.get("IdentifierList", {}).get("CID", [None])[0]
    except: # Catch all request exceptions
        return None
    return None

# --- Step 4: Define Model Architectures ---

class SimpleGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim=128, out_dim=1):
        super().__init__()
        nn1 = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.conv1 = GINConv(nn1)
        nn2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
        self.conv2 = GINConv(nn2)
        self.lin1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.lin2 = nn.Linear(hidden_dim // 2, out_dim)
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_add_pool(x, batch)
        x = F.relu(self.lin1(x))
        return self.lin2(x).view(-1)

class SELFIES_RNN(nn.Module):
    def __init__(self, vocab_size, emb_size=128, hidden_size=512, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=0)
        self.rnn = nn.LSTM(emb_size, hidden_size, num_layers, batch_first=True, dropout=0.2)
        self.fc = nn.Linear(hidden_size, vocab_size)
    def forward(self, x, hidden=None):
        emb = self.embedding(x)
        out, hidden = self.rnn(emb, hidden)
        return self.fc(out), hidden

def sample_selfies(model, token2idx, idx2token, max_len=50, temperature=1.0):
    model.eval()
    start_token = '[C]'
    x = torch.tensor([[token2idx[start_token]]], device=DEVICE)
    hidden = None
    tokens = [start_token]
    for _ in range(max_len):
        out, hidden = model(x, hidden)
        probs = F.softmax(out.squeeze() / temperature, dim=-1)
        idx = torch.multinomial(probs, 1).item()
        if idx == 0: break
        tokens.append(idx2token[idx])
        x = torch.tensor([[idx]], device=DEVICE)
    try:
        return sf.decoder(''.join(tokens))
    except:
        return None

# =============================================================================
# Main Execution Block
# =============================================================================
if __name__ == "__main__":
    print(f"\nProject starting on device: {DEVICE}\n")
    path = kagglehub.dataset_download("nikitamanaenkov/qm40-molecular-qm-dataset")
    df_main = pd.read_csv(os.path.join(path, "main.csv"))

    print("\n--- 1. Preparing Data for GNN ---")
    gnn_data_list = []
    subset_df_gnn = df_main.head(N_MOLECULES_GNN)
    for _, row in tqdm(subset_df_gnn.iterrows(), total=subset_df_gnn.shape[0], desc="Creating GNN graphs"):
        graph = smiles_to_graph(row['smile'])
        if graph:
            graph.y = torch.tensor([row['dipol_mom']], dtype=torch.float)
            graph.smiles = row['smile'] # Store SMILES for later comparison
            gnn_data_list.append(graph)

    train_val_data, test_data = train_test_split(gnn_data_list, test_size=0.1, random_state=42)
    train_data, val_data = train_test_split(train_val_data, test_size=0.11, random_state=42) # 0.11 * 0.9 = ~0.1

    train_targets = np.array([d.y.item() for d in train_data]).reshape(-1, 1)
    scaler = StandardScaler().fit(train_targets)
    for d in train_val_data + test_data:
        d.y = torch.tensor(scaler.transform([[d.y.item()]])[0], dtype=torch.float)

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    print("\n--- 2. Training Predictive GNN ---")
    gnn_model = SimpleGNN(in_dim=train_data[0].num_node_features).to(DEVICE)
    optimizer_gnn = torch.optim.Adam(gnn_model.parameters(), lr=LEARNING_RATE)
    loss_fn_gnn = nn.MSELoss()
    best_val_mae = float('inf')
    for epoch in range(1, N_EPOCHS_GNN + 1):
        gnn_model.train()
        for batch in train_loader:
            batch = batch.to(DEVICE)
            optimizer_gnn.zero_grad()
            out = gnn_model(batch)
            loss = loss_fn_gnn(out, batch.y)
            loss.backward()
            optimizer_gnn.step()
        val_mae, _ = evaluate_gnn(val_loader, gnn_model, scaler)
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            torch.save(gnn_model.state_dict(), 'best_gnn_model.pth')
        print(f"GNN Epoch {epoch:02d} | Val MAE: {val_mae:.4f} Debye")

    print("\n--- 3. Final GNN Evaluation on Unseen Test Set ---")
    gnn_model.load_state_dict(torch.load('best_gnn_model.pth'))
    test_mae, test_accuracy = evaluate_gnn(test_loader, gnn_model, scaler)
    print(f"GNN Performance on Test Set:")
    print(f"  - Mean Absolute Error (MAE): {test_mae:.4f} Debye")
    print(f"  - Average Accuracy:          {test_accuracy:.2f}%")

    print("\n--- 4. Training Generative RNN ---")
    selfies_list = [sf.encoder(smi) for smi in tqdm(df_main['smile'].head(N_MOLECULES_RNN), desc="Encoding to SELFIES") if smi and sf.encoder(smi)]
    all_tokens = set(t for s in selfies_list for t in sf.split_selfies(s))
    token2idx = {t: i + 1 for i, t in enumerate(sorted(all_tokens))}; token2idx['<PAD>'] = 0
    idx2token = {i: t for t, i in token2idx.items()}; vocab_size = len(token2idx)
    max_len = max(len(list(sf.split_selfies(s))) for s in selfies_list)
    selfies_tensor = torch.stack([torch.tensor([token2idx.get(t, 0) for t in list(sf.split_selfies(s))] + [0] * (max_len - len(list(sf.split_selfies(s)))), dtype=torch.long) for s in selfies_list])
    rnn_dataset = torch.utils.data.TensorDataset(selfies_tensor[:, :-1], selfies_tensor[:, 1:])
    rnn_loader = torch.utils.data.DataLoader(rnn_dataset, batch_size=128, shuffle=True)
    rnn_model = SELFIES_RNN(vocab_size).to(DEVICE)
    opt_rnn = torch.optim.Adam(rnn_model.parameters(), lr=1e-3)
    for epoch in range(1, N_EPOCHS_RNN + 1):
        for x, y in rnn_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            opt_rnn.zero_grad()
            out, _ = rnn_model(x)
            loss = F.cross_entropy(out.reshape(-1, vocab_size), y.reshape(-1), ignore_index=0)
            loss.backward()
            opt_rnn.step()
        print(f"RNN Epoch {epoch:02d}, Loss: {loss.item():.4f}")

    print("\n--- 5. Generating and Analyzing Novel Molecules ---")
    generated_smiles = [sample_selfies(rnn_model, token2idx, idx2token, temperature=0.9) for _ in tqdm(range(200), desc="Generating Molecules")]
    valid_smiles = [s for s in generated_smiles if s and Chem.MolFromSmiles(s)]

    novel_molecules_smiles = []
    existing_molecules = []
    for s in tqdm(valid_smiles, desc="Checking Novelty"):
        cid = check_pubchem(s)
        if cid is None:
            novel_molecules_smiles.append(s)
        else:
            existing_molecules.append({'smiles': s, 'cid': cid})


    print(f"\nGenerated {len(valid_smiles)} valid molecules.")
    print(f"Found {len(novel_molecules_smiles)} potentially novel molecules (not in PubChem).")
    print(f"Found {len(existing_molecules)} generated molecules that exist in PubChem.")


    if novel_molecules_smiles:
        novel_graphs = [smiles_to_graph(s) for s in novel_molecules_smiles]
        novel_graphs = [g for g in novel_graphs if g is not None]
        if novel_graphs:
            predict_loader = DataLoader(novel_graphs, batch_size=len(novel_graphs))
            gnn_model.eval()
            with torch.no_grad():
                batch = next(iter(predict_loader)).to(DEVICE)
                preds_scaled = gnn_model(batch).cpu().numpy()
                preds_real = scaler.inverse_transform(preds_scaled.reshape(-1, 1)).flatten()

            print("\n--- Predictions for Top 10 Novel Molecules ---")
            for i, smi in enumerate(novel_molecules_smiles[:10]):
                print(f"  - SMILES: {smi:<45} | Predicted Dipole: {preds_real[i]:.4f} Debye")

    # --- 6. GNN Performance Spotlight: Comparing Predictions to Ground Truth ---
    print("\n--- 6. GNN Performance Spotlight: Comparing Predictions to Ground Truth ---")
    print("Taking a random sample of 15 molecules from the unseen Test Set to demonstrate model accuracy.")

    sample_size = min(15, len(test_data))
    sample_indices = np.random.choice(len(test_data), sample_size, replace=False)
    sample_data = [test_data[i] for i in sample_indices]
    sample_loader = DataLoader(sample_data, batch_size=sample_size)

    gnn_model.eval()
    with torch.no_grad():
        batch = next(iter(sample_loader))
        batch = batch.to(DEVICE)
        predictions_scaled = gnn_model(batch).cpu().numpy()
        predicted_dipoles = scaler.inverse_transform(predictions_scaled.reshape(-1, 1)).flatten()
        actual_dipoles_scaled = batch.y.cpu().numpy()
        actual_dipoles = scaler.inverse_transform(actual_dipoles_scaled.reshape(-1, 1)).flatten()

    comparison_results = []
    for i in range(sample_size):
        smiles = sample_data[i].smiles
        predicted = predicted_dipoles[i]
        actual = actual_dipoles[i]
        error = abs(predicted - actual)
        comparison_results.append([smiles, f"{predicted:.4f}", f"{actual:.4f}", f"{error:.4f}"])

    comparison_df = pd.DataFrame(comparison_results, columns=["SMILES", "GNN Predicted (D)", "Actual Value (D)", "Absolute Error"])
    print("\n--- Comparison Table: Predicted vs. Actual Dipole Moment ---")

    try:
        from IPython.display import display, HTML
        display(HTML(comparison_df.to_html(index=False, justify='left')))
    except ImportError:
        print(comparison_df.to_string(index=False))

    print("\n" + "="*60)
    print("                PROJECT COMPLETE")
    print("="*60)

--- Checking and installing dependencies ---
Detected PyTorch 2.8.0 and device type cu126.
--- All dependencies installed successfully. ---

Project starting on device: cuda

Using Colab cache for faster access to the 'qm40-molecular-qm-dataset' dataset.

--- 1. Preparing Data for GNN ---


Creating GNN graphs: 100%|██████████| 20000/20000 [00:18<00:00, 1079.02it/s]



--- 2. Training Predictive GNN ---
GNN Epoch 01 | Val MAE: 1.2018 Debye
GNN Epoch 02 | Val MAE: 1.2395 Debye
GNN Epoch 03 | Val MAE: 1.1820 Debye
GNN Epoch 04 | Val MAE: 1.1677 Debye
GNN Epoch 05 | Val MAE: 1.1556 Debye
GNN Epoch 06 | Val MAE: 1.1290 Debye
GNN Epoch 07 | Val MAE: 1.1486 Debye
GNN Epoch 08 | Val MAE: 1.1556 Debye
GNN Epoch 09 | Val MAE: 1.1182 Debye
GNN Epoch 10 | Val MAE: 1.1102 Debye
GNN Epoch 11 | Val MAE: 1.1088 Debye
GNN Epoch 12 | Val MAE: 1.0912 Debye
GNN Epoch 13 | Val MAE: 1.0801 Debye
GNN Epoch 14 | Val MAE: 1.0758 Debye
GNN Epoch 15 | Val MAE: 1.0625 Debye
GNN Epoch 16 | Val MAE: 1.0513 Debye
GNN Epoch 17 | Val MAE: 1.1580 Debye
GNN Epoch 18 | Val MAE: 1.0682 Debye
GNN Epoch 19 | Val MAE: 1.0400 Debye
GNN Epoch 20 | Val MAE: 1.0545 Debye
GNN Epoch 21 | Val MAE: 1.0476 Debye
GNN Epoch 22 | Val MAE: 1.0520 Debye
GNN Epoch 23 | Val MAE: 1.0488 Debye
GNN Epoch 24 | Val MAE: 1.0478 Debye
GNN Epoch 25 | Val MAE: 1.0389 Debye
GNN Epoch 26 | Val MAE: 1.0283 Debye
GN

Encoding to SELFIES: 100%|██████████| 50000/50000 [00:23<00:00, 2126.40it/s]


RNN Epoch 01, Loss: 1.1319
RNN Epoch 02, Loss: 0.9718
RNN Epoch 03, Loss: 0.9331
RNN Epoch 04, Loss: 0.8816
RNN Epoch 05, Loss: 0.8967
RNN Epoch 06, Loss: 0.8519
RNN Epoch 07, Loss: 0.7794
RNN Epoch 08, Loss: 0.8007
RNN Epoch 09, Loss: 0.7817
RNN Epoch 10, Loss: 0.7662
RNN Epoch 11, Loss: 0.7476
RNN Epoch 12, Loss: 0.7537
RNN Epoch 13, Loss: 0.7375
RNN Epoch 14, Loss: 0.7389
RNN Epoch 15, Loss: 0.7259
RNN Epoch 16, Loss: 0.6947
RNN Epoch 17, Loss: 0.7147
RNN Epoch 18, Loss: 0.7435
RNN Epoch 19, Loss: 0.6870
RNN Epoch 20, Loss: 0.6889
RNN Epoch 21, Loss: 0.6874
RNN Epoch 22, Loss: 0.6706
RNN Epoch 23, Loss: 0.6701
RNN Epoch 24, Loss: 0.6575
RNN Epoch 25, Loss: 0.6557

--- 5. Generating and Analyzing Novel Molecules ---


Generating Molecules: 100%|██████████| 200/200 [00:07<00:00, 28.26it/s]
Checking Novelty: 100%|██████████| 200/200 [01:52<00:00,  1.77it/s]


Generated 200 valid molecules.
Found 29 potentially novel molecules (not in PubChem).
Found 171 generated molecules that exist in PubChem.

--- Predictions for Top 10 Novel Molecules ---
  - SMILES: CS[C@H1](C)C(=O)O[C@@H1](C)C1=CN=CS1C(C2)CCC2=O | Predicted Dipole: 3.5528 Debye
  - SMILES: CC1=CC=C(COC(=O)CCCF)S1N2CCOC3(CCOCC3)C2=O    | Predicted Dipole: 3.0698 Debye
  - SMILES: CO/N=C/C(=O)N1C[C@@H1](O)C[C@@H1]1COCC(F)(F)F | Predicted Dipole: 3.5972 Debye
  - SMILES: CCO/C=C\C(=O)N[C@H1]1COC[C@H1]1OC(C)(C)CC(F)(F)F | Predicted Dipole: 3.5781 Debye
  - SMILES: CC[C@@H1](C)C1=NC(C2CCC2)=CS1C(=O)NC3=CC=CC=C3C#N | Predicted Dipole: 4.0921 Debye
  - SMILES: COC(=O)COC(=O)C1=CC=C(Cl)S1C2=NOC(C(C)(C)C)=N2 | Predicted Dipole: 3.2561 Debye
  - SMILES: C[S@](=O)CC1NCC2=CC=C(Cl)S2C1=O               | Predicted Dipole: 3.8233 Debye
  - SMILES: C1SC=CC=CC1(CNC(=O)N[C@H1](C)C2=NC=C(C3=CC=CC=C3Cl)O2)C#N | Predicted Dipole: 4.5208 Debye
  - SMILES: CSCCC(=O)N[C@H1]1CCCOCC12CC2C3CC3C4=CC=CC=C4C#N | 




SMILES,GNN Predicted (D),Actual Value (D),Absolute Error
CCCNC(=O)CSCCCC(C)C,2.8614,2.8284,0.033
CCOC(=O)CC[C@@H](C)NCC(F)F,2.4735,1.4859,0.9876
COCC1=CCN(C(=O)Cc2ccsc2)CC1,3.3869,4.7323,1.3454
CON(C)C(=O)N[C@@H](C(C)C)C1CC1,2.5698,3.2904,0.7206
C=C(Cl)CN1CC[C@@H](C)C[C@H]1C(N)=O,3.0988,4.429,1.3302
Cc1cc(C(=O)NCC[C@H](C)F)no1,3.9091,2.7445,1.1646
C[C@@H](O)C[C@H](C)CNC(=S)NC1CC1,4.5577,5.1061,0.5484
C[C@H](CO)CS[C@H]1CCC[C@H](C)C1,1.9041,2.3042,0.4001
CC[C@@H](C)C[C@@H](C)NC(=O)CCS(=O)(=O)CC(C)C,3.8611,4.4998,0.6387
CCONC(=O)N[C@@H]1COC[C@H]1OC,3.034,1.7316,1.3024



                PROJECT COMPLETE
