### De Novo Design of EGFR Inhibitors using a Deep Generative Model

* **Author:** Hosein Mohammadi
* **Date:** July 2024
* **Contact:** [huseinmohammadi83@gmail.com](mailto:huseinmohammadi83@gmail.com)
* **LinkedIn:** [Hosein Mohammadi](https://www.linkedin.com/in/hosein-mohammadi-979b8a2b2/)
* **Project Repository:** [VAE-RL-Drug-Design](https://github.com/Hosein541/VAE-RL-Drug-Design)
---

### Project Overview

This notebook implements an end-to-end pipeline for de novo drug design, a key process in computational chemistry and drug discovery. The primary objective is to develop and optimize a deep generative model capable of designing novel molecular structures. The model is trained on compounds with measured IC50 values against the **Epidermal Growth Factor Receptor (EGFR) kinase domain**, a crucial target in modern cancer therapy, so that the generated molecules are candidate EGFR inhibitors.

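The project combines the SELFIES molecular representation, a variational autoencoder trained with curriculum learning, and reinforcement-learning fine-tuning to favor the generation of valid, unique, and drug-like molecules with desirable properties.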

---

### 1. Environment Setup & Library Imports

**Objective:**
This initial cell prepares the notebook's environment. It installs all required external libraries and then imports all necessary modules that will be used throughout the project.

**Methodology:**
1.  **Installation:** The `pip` command is used to install three core external libraries:
    * `chembl_webresource-client`: To programmatically access and download data from the ChEMBL database.
    * `rdkit`: The primary toolkit for cheminformatics, used for processing molecular structures and calculating properties (previously published on PyPI as `rdkit-pypi`, which is now deprecated).
    * `selfies`: A library for the SELFIES (SELF-referencIng Embedded Strings) representation of molecules. Any sequence of SELFIES tokens decodes to a valence-valid molecule, so every generated string yields a valid structure.

2.  **Imports:** The required modules are imported and grouped by category: standard Python libraries, data science/visualization tools, cheminformatics libraries, and the PyTorch deep learning framework.

3.  **SA_Score Configuration:** The path to RDKit's Synthetic Accessibility (SA) score module is explicitly added to the system path to ensure it can be imported and used.

**Outcome:**
Upon successful execution, all dependencies are loaded, and a confirmation message is printed. The notebook is now ready for the data acquisition and processing steps.

In [None]:
# Install the necessary libraries
!pip install chembl_webresource-client
!pip install rdkit
!pip install selfies


# Section 1: Standard Python Libraries
import os
import sys
from pathlib import Path
from typing import List


# Section 2: Core (Third-party) Libraries

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import selfies as sf

# --- Chemoinformatics Libraries (RDKit & ChEMBL) ---
from chembl_webresource_client.new_client import new_client
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, QED, RDConfig
# Add the SA_Score module (RDKit Contrib) to the Python path for synthetic accessibility scoring
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer


# Section 3: Deep Learning Libraries (PyTorch)

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

print("✅ All required libraries imported successfully.")

### 2. Data Acquisition and Preprocessing from ChEMBL

**Objective:**
This cell downloads, cleans, and processes the bioactivity data for our target protein, EGFR, from the ChEMBL database. The goal is to create a clean dataset of molecules and their corresponding inhibitory potencies (IC50 values) that will be used to train the generative model.

**Methodology:**
1.  **Data Retrieval:** The code connects to the ChEMBL database using its API client. It targets the **Epidermal Growth Factor Receptor (EGFR)** using its unique ChEMBL ID, `CHEMBL203`. It then fetches up to 30,000 associated activity records, specifically requesting the molecule's structure (`canonical_smiles`) and its `standard_type` and `standard_value`.
2.  **Data Cleaning & Filtering:** The raw data is converted into a pandas DataFrame and undergoes several crucial cleaning steps:
    * The dataset is filtered to keep only records where the activity is measured as **IC50**, a standard measure of a drug's potency.
    * Any rows that have missing values for the IC50 or the SMILES string are removed to ensure data quality.
    * Non-numeric and non-positive IC50 values are discarded, since the logarithm used in the next step is undefined for them.
3.  **Feature Engineering (pIC50):** A new feature, **pIC50**, is calculated from the IC50 values. The pIC50 is the negative base-10 logarithm of the IC50 expressed in molar (mol/L) units. The code assumes the values are in nM (ChEMBL's standard unit for IC50), so each value is multiplied by 10⁻⁹ before the logarithm is taken; an IC50 of 10 nM, for example, gives a pIC50 of 8. On this logarithmic scale higher values indicate greater potency, and the compressed range is better suited to machine learning models.
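The cleaning rules and the pIC50 conversion can be sketched in plain Python on a few toy records. This is an illustration under stated assumptions, not the notebook's code: the records are invented, the helper names `clean_records` and `ic50_nm_to_pic50` are hypothetical, and the `standard_units == "nM"` check is an added assumption, because the ×10⁻⁹ conversion is only correct for nanomolar values.

```python
import math

def ic50_nm_to_pic50(ic50_nm: float) -> float:
    """pIC50 = -log10(IC50 in mol/L); the input is assumed to be in nM."""
    return -math.log10(ic50_nm * 1e-9)

def clean_records(records):
    """Keep IC50 rows that have a SMILES string and a positive numeric value in nM."""
    cleaned = []
    for r in records:
        if r.get("standard_type") != "IC50" or not r.get("canonical_smiles"):
            continue
        if r.get("standard_units") != "nM":   # assumption: only nM values are converted
            continue
        try:
            value = float(r["standard_value"])
        except (TypeError, ValueError):       # mirrors pd.to_numeric(errors="coerce")
            continue
        if value > 0:
            cleaned.append({"smiles": r["canonical_smiles"], "IC50_nM": value,
                            "pIC50": ic50_nm_to_pic50(value)})
    return cleaned

# Invented example records (not ChEMBL data)
toy = [
    {"canonical_smiles": "CCO", "standard_type": "IC50", "standard_value": "10", "standard_units": "nM"},
    {"canonical_smiles": "CCN", "standard_type": "Ki", "standard_value": "5", "standard_units": "nM"},
    {"canonical_smiles": None, "standard_type": "IC50", "standard_value": "3", "standard_units": "nM"},
    {"canonical_smiles": "CCC", "standard_type": "IC50", "standard_value": "n/a", "standard_units": "nM"},
    {"canonical_smiles": "CCCl", "standard_type": "IC50", "standard_value": "1000", "standard_units": "nM"},
]
for row in clean_records(toy):
    print(row["smiles"], round(row["pIC50"], 2))
```

Running it prints `CCO 8.0` and `CCCl 6.0`: the `Ki` row, the row without a SMILES string, and the non-numeric value are all dropped.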

**Outcome:**
The final, cleaned dataset is saved to `egfr_ic50_dataset.csv`. This file contains the essential information for each molecule—its SMILES string, IC50 value, and the calculated pIC50—and serves as the foundational dataset for the next steps.

In [None]:
# --- Settings ---
target_id = "CHEMBL203"
max_records = 30000

activity = new_client.activity

print("⏳ Fetching data from ChEMBL...")

# EGFR data generator
activity_gen = activity.filter(
    target_chembl_id=target_id
).only(
    [
        "canonical_smiles",
        "standard_type",
        "standard_value",
        "standard_units",  # needed to check the nM assumption used for pIC50
        "molecule_chembl_id",
        "activity_id",
    ]
)

# Manually retrieve data up to max_records
records = []
for i, record in enumerate(activity_gen):
    if i >= max_records:
        break
    records.append(record)
    if i % 500 == 0:
        print(f"🔄 Received record {i}")

print(f"🎯 Number of records received: {len(records)}")

# --- Convert to DataFrame and Preprocessing ---
df = pd.DataFrame(records)

# Keep IC50 measurements reported in nM that have a SMILES string and a numeric value
df = df[(df["standard_type"] == "IC50") & (df["standard_units"] == "nM")]
df = df[df["standard_value"].notna() & df["canonical_smiles"].notna()].copy()
df["standard_value"] = pd.to_numeric(df["standard_value"], errors="coerce")
df = df[df["standard_value"] > 0].copy()  # also drops values that failed to parse (NaN)

# Calculate pIC50: convert nM to M (x 1e-9), then take -log10
df["pIC50"] = -np.log10(df["standard_value"] * 1e-9)

# Final output
final_df = df[["molecule_chembl_id", "canonical_smiles", "standard_value", "pIC50"]]
final_df.columns = ["chembl_id", "smiles", "IC50_nM", "pIC50"]
final_df.to_csv("egfr_ic50_dataset.csv", index=False)

print("✅ File egfr_ic50_dataset.csv saved.")
print(f"📦 Final records: {final_df.shape[0]}")

### 3. Data Preparation and SMILES Tokenization

**Objective:**
This cell takes the processed dataset and prepares the SMILES strings for input into the neural network. This involves creating a character-level vocabulary, converting the strings into sequences of integer tokens, and padding them to a uniform length to be batched into a tensor.

**Methodology:**
1.  **Data Loading:** The `egfr_ic50_dataset.csv` file is loaded into a pandas DataFrame, and the `smiles` column is extracted into a list. This list is also saved to `train.smi`, a common file format for molecular datasets.
2.  **Vocabulary Creation:** A character-level vocabulary is built by finding all unique characters present across all SMILES strings in the dataset. Four special tokens are added to this vocabulary:
    * `<PAD>`: For padding sequences to the same length.
    * `<START>`: To signify the beginning of a sequence.
    * `<END>`: To signify the end of a sequence.
    * `<UNK>`: To represent any unknown characters.
    
    Mappings from characters to integer indices (`char2idx`) and vice-versa (`idx2char`) are then created.
3.  **Tokenization & Padding:** Each SMILES string is converted into a numerical sequence. This is done by:
    * Adding `<START>` and `<END>` tokens to the beginning and end of each string.
    * Mapping every character to its corresponding integer index from the `char2idx` dictionary.
    * Padding every sequence with the `<PAD>` token's index until it reaches the `max_length` (defined by the longest molecule in the dataset). This ensures all input sequences have a uniform size, which is necessary for batch processing in PyTorch.
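The vocabulary and padding steps above can be reproduced on two small molecules with plain Python. The function names `build_vocab` and `tokenize` are hypothetical stand-ins for the notebook's inline code and `tokenize_smiles`; the logic (special tokens first, sorted character set, `<UNK>` fallback, right padding to `max_length`) follows the description.

```python
SPECIAL_TOKENS = ["<PAD>", "<START>", "<END>", "<UNK>"]

def build_vocab(smiles_list):
    """Special tokens first (so <PAD> has index 0), then the sorted character set."""
    charset = sorted(set("".join(smiles_list)))
    vocab = SPECIAL_TOKENS + charset
    char2idx = {ch: i for i, ch in enumerate(vocab)}
    idx2char = {i: ch for ch, i in char2idx.items()}
    return char2idx, idx2char

def tokenize(smi, char2idx, max_length):
    """Wrap in <START>/<END>, map characters to ids (unknown -> <UNK>), right-pad with <PAD>."""
    tokens = ["<START>"] + list(smi) + ["<END>"]
    ids = [char2idx.get(t, char2idx["<UNK>"]) for t in tokens]
    return ids + [char2idx["<PAD>"]] * (max_length - len(ids))

smiles = ["CCO", "c1ccccc1"]
char2idx, idx2char = build_vocab(smiles)
max_length = max(len(s) for s in smiles) + 2   # +2 for <START> and <END>
print(tokenize("CCO", char2idx, max_length))
print(tokenize("CCN", char2idx, max_length))   # 'N' does not occur in the toy vocabulary
```

Here the vocabulary is `<PAD>`, `<START>`, `<END>`, `<UNK>`, `1`, `C`, `O`, `c` (indices 0 to 7), so `CCO` becomes `[1, 5, 5, 6, 2, 0, 0, 0, 0, 0]` and the unseen `N` in `CCN` maps to `<UNK>` (3). Because the vocabulary is character-level, two-letter atoms such as `Cl` or `Br` are split into two tokens, which the model has to learn to keep together.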

**Outcome:**
The final output is `encoded_tensor`, a PyTorch `LongTensor` in which each row represents one molecule and each entry is a character index from the vocabulary. Note that the VAE in the next section does not consume this tensor: it re-reads `train.smi` and tokenizes the molecules as SELFIES, so the artifact from this cell that the later steps depend on is `train.smi`.

In [None]:
# Load the dataset
df = pd.read_csv("egfr_ic50_dataset.csv")

# We only use the SMILES column for reconstruction
smiles_list = df['smiles'].tolist()
print(f"Total SMILES count: {len(smiles_list)}")

with open('train.smi', 'w', encoding='utf-8') as f:
    f.write('\n'.join(smiles_list))


# Define special tokens
special_tokens = ['<PAD>', '<START>', '<END>', '<UNK>']

# Extract all unique characters from SMILES
charset = set(''.join(smiles_list))
charset = sorted(list(charset))

# Build the final vocabulary
vocab = special_tokens + charset
char2idx = {ch: idx for idx, ch in enumerate(vocab)}
idx2char = {idx: ch for ch, idx in char2idx.items()}

vocab_size = len(vocab)
print(f"Total number of characters (vocab size): {vocab_size}")


# Set the maximum length from the longest SMILES in the dataset
max_length = max(len(s) for s in smiles_list) + 2  # +2 for <START> and <END>
print(f"Selected sequence length: {max_length}")

def tokenize_smiles(smi):
    tokens = ['<START>'] + list(smi) + ['<END>']
    idxs = [char2idx.get(ch, char2idx['<UNK>']) for ch in tokens]
    # Padding
    padding = [char2idx['<PAD>']] * (max_length - len(idxs))
    return idxs + padding

encoded_smiles = [tokenize_smiles(smi) for smi in smiles_list]
encoded_tensor = torch.tensor(encoded_smiles, dtype=torch.long)

### 4. VAE Model Training with Curriculum Learning

**Objective:**
This cell defines and trains the core generative model, a Variational Autoencoder (VAE) based on SELFIES. It employs a **Curriculum Learning** strategy to guide the model toward generating more desirable, drug-like molecules.

**Methodology:**
This process is divided into several logical parts within the code:

1.  **Configuration:** Key hyperparameters for the model and training process are defined, including learning rate (`LR`), batch size (`BATCH_SIZE`), and the number of epochs for pre-training (`PRE_EPOCHS`) and curriculum fine-tuning (`CURR_EPOCHS`).
2.  **Helper Classes & Functions:**
    * `SelfiesTokenizer`: Converts SELFIES strings into integer tensors for the model.
    * `SelfiesDataset`: A PyTorch `Dataset` class to handle the data loading.
    * `VAE`: The neural network architecture, featuring a bidirectional GRU encoder and a GRU decoder. It learns a compressed latent space representation (`z`) of the molecules.
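    * `loss_fn`: Calculates the VAE's loss as the reconstruction loss (cross-entropy between the predicted and target tokens, ignoring padding, i.e. how well the model recreates an input molecule) plus a weighted KL divergence that regularizes the latent space. The KL weight `beta` follows a cyclical schedule (`cyc_beta`), and each sample's KL term is floored at `FREE_BITS` so that the model is not pushed to collapse the posterior onto the prior.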
    * `keep`: A helper function that defines our criteria for a "high-quality" molecule based on QED, SA score, and molecular weight (MW).
3.  **Training Workflow (Curriculum Learning):**
    * **Pre-training:** The VAE model is first trained for `PRE_EPOCHS` on the **entire dataset**. This allows the model to learn the general grammar and patterns of the whole chemical space.
    * **Filtering:** The dataset is then filtered using the `keep` function, creating a smaller, higher-quality subset of molecules that meet our drug-like criteria.
    * **Fine-tuning:** The pre-trained model is then trained for an additional `CURR_EPOCHS` exclusively on this **filtered dataset**. This second stage fine-tunes the model, biasing it to generate new molecules that share the desirable properties of the high-quality subset.
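The KL term and the two schedules used in `train_epoch` can be written out in plain Python to inspect their values. This sketch re-implements the formulas of `kl_div`, the `FREE_BITS` clamp in `loss_fn`, `cyc_beta`, and the `tfr` expression on lists instead of tensors; the function names `kl_per_sample`, `free_bits_kl`, and `teacher_forcing_ratio` are hypothetical.

```python
import math

def kl_per_sample(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)) summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def free_bits_kl(mu_batch, logvar_batch, free_bits=1.0):
    """Floor each sample's KL at `free_bits`, then average over the batch (as in loss_fn)."""
    kls = [max(kl_per_sample(m, lv), free_bits) for m, lv in zip(mu_batch, logvar_batch)]
    return sum(kls) / len(kls)

def cyc_beta(step, total, cycles=4, max_beta=4.0):
    """Cyclical KL weight: ramps linearly from 0 toward max_beta, then resets, `cycles` times."""
    period = total // cycles
    return max_beta * ((step % period) / period)

def teacher_forcing_ratio(step, total):
    """Linear decay from 1.0 to a floor of 0.3 over the first 25% of training steps."""
    return max(0.3, 1 - step / (0.25 * total))

print([round(cyc_beta(s, 100), 2) for s in (0, 12, 24, 25)])
print([round(teacher_forcing_ratio(s, 100), 2) for s in (0, 10, 18, 50)])
print(free_bits_kl([[0.0, 0.0], [2.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]))
```

With `total = 100` steps and four cycles, `cyc_beta` rises from 0 to 3.84 over steps 0 to 24 and resets to 0 at step 25, while the teacher-forcing ratio reaches its floor of 0.3 at step 18. The notebook floors each sample's total KL; the free-bits idea is more commonly applied per latent dimension or group of dimensions.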

**Outcome:**
The cell trains the VAE model and saves its learned weights and SELFIES vocabulary to a file named `vae_curriculum.pt`. Because the second stage trains only on molecules that pass the `keep` filter, the model is expected to generate a higher proportion of drug-like compounds than a model trained on the full dataset alone.

In [None]:
PRE_EPOCHS, CURR_EPOCHS = 15, 35
BATCH_SIZE, EMB, HID, LAT = 512, 128, 256, 128
LR, FREE_BITS = 1e-3, 1.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
QED_MIN, SA_MAX, MW_MIN, MW_MAX = 0.5, 5.8, 150, 600

class SelfiesTokenizer:
    def __init__(self, selfies: List[str]):
        specials = ["<PAD>", "<BOS>", "<EOS>"]
        vocab = sorted({t for s in selfies for t in sf.split_selfies(s)})
        self.idx2tok = specials + vocab
        self.tok2idx = {t: i for i, t in enumerate(self.idx2tok)}
        self.pad = self.tok2idx["<PAD>"]
        self.bos = self.tok2idx["<BOS>"]
        self.eos = self.tok2idx["<EOS>"]
        self.pad_idx = self.pad
        self.bos_idx = self.bos
        self.eos_idx = self.eos
    def encode(self, s: str) -> torch.LongTensor:
        ids = [self.bos] + [self.tok2idx[t] for t in sf.split_selfies(s)] + [self.eos]
        return torch.tensor(ids, dtype=torch.long)

class SelfiesDataset(Dataset):
    def __init__(self, selfies: List[str], tok: SelfiesTokenizer):
        self.data, self.tok = selfies, tok
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.tok.encode(self.data[i])

def collate(batch, pad):
    batch = [torch.as_tensor(x) for x in batch]
    lens = torch.tensor([len(x) for x in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=pad)
    return padded[:, :-1], padded[:, 1:], lens - 1

class VAE(nn.Module):
    def __init__(self, vocab, pad, bos):
        super().__init__()
        self.pad, self.bos = pad, bos
        self.emb = nn.Embedding(vocab, EMB, padding_idx=pad)
        self.encoder = nn.GRU(EMB, HID, 2, batch_first=True, bidirectional=True, dropout=0.3)
        self.mu = nn.Linear(HID * 2, LAT)
        self.logvar = nn.Linear(HID * 2, LAT)
        self.decoder = nn.GRU(EMB + LAT, HID, 2, batch_first=True, dropout=0.3)
        self.out = nn.Linear(HID, vocab)
        self.out_proj = self.out
    def encode(self, x, l):
        _, h = self.encoder(nn.utils.rnn.pack_padded_sequence(self.emb(x), l.cpu(), batch_first=True, enforce_sorted=False))
        h = torch.cat([h[-2], h[-1]], -1)
        return self.mu(h), self.logvar(h)
    @staticmethod
    def reparam(mu, logvar):
        return mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)
    def decode(self, z, tgt, tfr):
        B, T = tgt.shape
        outputs, hidden = [], None
        tok = torch.full((B,), self.bos, device=tgt.device, dtype=torch.long)
        for t in range(T):
            dec_in = torch.cat([self.emb(tok), z], -1).unsqueeze(1)
            out, hidden = self.decoder(dec_in, hidden)
            logits = self.out(out.squeeze(1))
            logits[:, self.bos] = -1e9
            outputs.append(logits)
            tok = torch.where(torch.rand(B, device=tgt.device) < tfr, tgt[:, t], logits.argmax(-1))
        return torch.stack(outputs, 1)
    def forward(self, inp, tgt, l, tfr):
        mu, logvar = self.encode(inp, l)
        z = self.reparam(mu, logvar)
        return self.decode(z, tgt, tfr), mu, logvar

def kl_div(mu, logvar): return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)
def cyc_beta(step, total, cycles=4, max_beta=4.0):
    return max_beta * ((step % (total // cycles)) / (total // cycles))
def loss_fn(logits, tgt, mu, logvar, beta, pad):
    recon = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1), ignore_index=pad)
    kl = kl_div(mu, logvar).clamp(min=FREE_BITS).mean()
    return recon + beta * kl, recon, kl

def keep(smi: str):
    mol = Chem.MolFromSmiles(smi)
    if mol is None: return False
    q, sa, mw = QED.qed(mol), sascorer.calculateScore(mol), Descriptors.MolWt(mol)
    return q >= QED_MIN and sa <= SA_MAX and MW_MIN <= mw <= MW_MAX

def train_epoch(model, loader, opt, step, total):
    model.train(); r_sum = k_sum = 0
    for inp, tgt, lens in loader:
        inp, tgt, lens = inp.to(DEVICE), tgt.to(DEVICE), lens.to(DEVICE)
        beta = cyc_beta(step, total)
        tfr = max(0.3, 1 - step / (0.25 * total))
        logits, mu, logvar = model(inp, tgt, lens, tfr)
        loss, recon, kl = loss_fn(logits, tgt, mu, logvar, beta, model.pad)
        opt.zero_grad(); loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 1.0); opt.step()
        r_sum += recon.item(); k_sum += kl.item(); step += 1
    return r_sum / len(loader), k_sum / len(loader), step

def sample(model, tokenizer, *, max_len=150, n=32, device="cpu"):
    model.eval()
    with torch.no_grad():
        z = torch.randn(n, LAT, device=device)
        seq = torch.full((n, 1), tokenizer.bos_idx, dtype=torch.long, device=device)
        hidden = None
        for _ in range(max_len):
            dec_in = torch.cat([model.emb(seq[:, -1]), z], -1).unsqueeze(1)
            dec_out, hidden = model.decoder(dec_in, hidden)
            logits = model.out_proj(dec_out.squeeze(1))
            logits[:, tokenizer.bos_idx] = -1e9
            next_tok = torch.multinomial(F.softmax(logits, -1), 1)
            seq = torch.cat([seq, next_tok], 1)
            if (next_tok == tokenizer.eos_idx).all(): break
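    selfies = []
    for s in seq.tolist():
        ids = s[1:]                                  # drop the leading <BOS>
        if tokenizer.eos_idx in ids:                 # discard tokens sampled after this row's first <EOS>
            ids = ids[: ids.index(tokenizer.eos_idx)]
        selfies.append("".join(tokenizer.idx2tok[i] for i in ids if i != tokenizer.pad_idx))
    smiles_out = []
    for se in selfies:
        try:
            smiles_out.append(sf.decoder(se))
        except sf.DecoderError:                      # undecodable strings become empty SMILES
            smiles_out.append("")
    return smiles_out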

if __name__ == "__main__":
    src = Path("train.smi"); assert src.exists()
    smiles = [l.strip() for l in src.read_text().splitlines() if l.strip() and Chem.MolFromSmiles(l.strip())]
    selfies_all = [sf.encoder(s) for s in smiles]
    tok = SelfiesTokenizer(selfies_all)
    loader_all = DataLoader(SelfiesDataset(selfies_all, tok), BATCH_SIZE, True, collate_fn=lambda b: collate(b, tok.pad))
    model = VAE(len(tok.idx2tok), tok.pad, tok.bos).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=LR)
    total_steps, step = (PRE_EPOCHS + CURR_EPOCHS) * len(loader_all), 0
    for e in range(1, PRE_EPOCHS + 1):
        r, k, step = train_epoch(model, loader_all, opt, step, total_steps)
        print(f"[PRE {e}/{PRE_EPOCHS}] recon {r:.3f} kl {k:.3f}")
    smiles_filt = [s for s in smiles if keep(s)]
    Path("filtered.smi").write_text("\n".join(smiles_filt))
    selfies_filt = [sf.encoder(s) for s in smiles_filt]
    loader_filt = DataLoader(SelfiesDataset(selfies_filt, tok), BATCH_SIZE, True, collate_fn=lambda b: collate(b, tok.pad))
    for e in range(1, CURR_EPOCHS + 1):
        r, k, step = train_epoch(model, loader_filt, opt, step, total_steps)
        print(f"[CURR {e}/{CURR_EPOCHS}] recon {r:.3f} kl {k:.3f}")
    torch.save({"state_dict": model.state_dict(), "vocab": tok.idx2tok}, "vae_curriculum.pt")
    print("✅ training complete; model saved → vae_curriculum.pt")


### 5. Molecule Generation and Property Evaluation

**Objective:**
This cell uses the trained VAE model (after curriculum learning) to generate a large library of new, unseen molecules. It then evaluates the chemical properties of these generated molecules to assess the model's performance and create a dataset for further analysis.

**Methodology:**
1.  **Sampling Function:** The `sample` function is defined to generate new molecules. It works by:
    * Sampling random points (`z`) from the model's latent space.
    * Using the VAE's decoder to generate new sequences of SELFIES tokens, starting from a `<BOS>` (Beginning of Sequence) token.
    * The generation is done autoregressively, where the token predicted at each step is fed as the input for the next step.
    * Finally, the generated SELFIES strings are decoded into their corresponding SMILES representations.
2.  **Generation & Evaluation:**
    * A large batch of `N_SAMPLES` (5,000) molecules is generated by calling the `sample` function.
    * The code then iterates through the list of generated SMILES strings. Each valid SMILES is converted into an RDKit molecule object.
    * For each valid molecule, three key medicinal chemistry properties are calculated: **QED** (Quantitative Estimation of Drug-likeness), **SA Score** (Synthetic Accessibility), and **MW** (Molecular Weight).
    * These properties, along with the SMILES string, are stored in a pandas DataFrame.
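During batched sampling, the loop only stops early when every row has produced `<EOS>`, so rows that finish first keep receiving sampled tokens afterwards. A filter that merely drops `<EOS>` and `<PAD>` ids would splice those extra tokens onto the molecule. The sketch below cuts each row at its first `<EOS>` instead; it uses plain Python lists and a made-up six-entry vocabulary, and the helper name `ids_to_selfies` is hypothetical (with a PyTorch tensor, call `row.tolist()` first).

```python
def ids_to_selfies(row, idx2tok, bos, eos, pad):
    """Join the tokens of one sampled row, stopping at the first <EOS>.

    Tokens after the first <EOS> come from steps where this row had already
    finished, so they are discarded instead of being concatenated.
    """
    out = []
    for idx in row:
        if idx == eos:
            break
        if idx in (bos, pad):
            continue
        out.append(idx2tok[idx])
    return "".join(out)

idx2tok = ["<PAD>", "<BOS>", "<EOS>", "[C]", "[O]", "[N]"]
rows = [
    [1, 3, 3, 4, 2, 5, 5],   # finished early; [N][N] was sampled after <EOS>
    [1, 3, 5, 3, 3, 4, 2],
]
print([ids_to_selfies(r, idx2tok, bos=1, eos=2, pad=0) for r in rows])
```

The first row yields `[C][C][O]` rather than `[C][C][O][N][N]`.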

**Outcome:**
A DataFrame of valid molecules and their calculated properties is created; in the reported run, 4,989 of the 5,000 samples were kept. The descriptive statistics (count, mean, standard deviation, quartiles) of these properties are printed to provide an initial assessment of the generated chemical space. The entire dataset is then saved to `generated_molecules.csv` for the next stage of filtering and visualization.

In [None]:
# ---------- sampling ----------
def sample(model, tokenizer, *, max_len=150, n=32, device="cpu"):
    model.eval()
    with torch.no_grad():
        z = torch.randn(n, model.mu.out_features, device=device)
        seq = torch.full((n, 1), tokenizer.bos, dtype=torch.long, device=device)
        hidden = None
        for _ in range(max_len):
            dec_in  = torch.cat([model.emb(seq[:, -1]), z], dim=-1).unsqueeze(1)
            dec_out, hidden = model.decoder(dec_in, hidden)
            logits  = model.out_proj(dec_out.squeeze(1))   # alias set in VAE
            logits[:, tokenizer.bos] = -1e9
            next_tok = torch.multinomial(torch.softmax(logits, dim=-1), 1)
            seq = torch.cat([seq, next_tok], dim=1)
            if (next_tok == tokenizer.eos).all():
                break

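    # Cut each row at its first <EOS>: rows that finish early keep receiving
    # sampled tokens until every row in the batch has emitted <EOS>.
    selfies_list = []
    for s in seq.tolist():
        ids = s[1:]                      # drop the leading <BOS>
        if tokenizer.eos in ids:
            ids = ids[: ids.index(tokenizer.eos)]
        selfies_list.append("".join(tokenizer.idx2tok[i] for i in ids if i != tokenizer.pad))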
    smiles = []
    for se in selfies_list:
        try:
            smiles.append(sf.decoder(se))
        except sf.DecoderError:
            smiles.append("")
    return smiles



# ================================
# Sampling + QED / SA evaluation
# ================================
N_SAMPLES, MAX_LEN = 5000, 150
CSV_NAME = "generated_molecules.csv"

smiles_list = sample(model, tok, n=N_SAMPLES, max_len=MAX_LEN, device=DEVICE)

records = []
for smi in smiles_list:
    mol = Chem.MolFromSmiles(smi) if smi else None
    if mol is None or mol.GetNumAtoms() == 0:   # skip failed decodes and empty molecules
        continue
    records.append(
        {
            "smiles": smi,
            "qed": QED.qed(mol),
            "sa": sascorer.calculateScore(mol),
            "mw": Descriptors.MolWt(mol),
        }
    )

df = pd.DataFrame(records)
print(f"Kept {len(df)}/{N_SAMPLES} molecules ({len(df)/N_SAMPLES:.1%})")
print(df[["qed", "sa", "mw"]].describe())

df.to_csv(CSV_NAME, index=False)
print(f"✅  CSV saved to {CSV_NAME}")


### 6. Hit Selection and Visualization of Generated Molecules

**Objective:**
This cell filters the large library of molecules generated by the curriculum-trained VAE to identify a smaller subset of high-quality "hits" based on established medicinal chemistry rules. It also generates plots to visualize the property distributions of the entire generated library, providing insight into the model's performance.

**Methodology:**
1.  **Hit Identification:** A set of filtering criteria is defined to select the most promising drug-like molecules from the `generated_molecules.csv` file:
    * **QED (Quantitative Estimation of Drug-likeness):** Must be greater than or equal to 0.6.
    * **SA Score (Synthetic Accessibility):** Must be less than or equal to 5.0 (a lower score indicates the molecule is easier to synthesize).
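    * **Molecular Weight (MW):** Must be within the 150-600 g/mol range. The lower bound removes solvents and small fragments; the upper bound is somewhat more permissive than the 500 Da limit of Lipinski's Rule of Five for oral bioavailability.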
    
    The DataFrame is filtered using these rules, and the resulting "hits" are counted and saved to a new DataFrame.
        
2.  **Visualization:** Three plots are generated using `matplotlib` to analyze the property space of the generated molecules:
    * **QED & SA Score Histograms:** These plots show the frequency distribution of the drug-likeness and synthetic accessibility scores for all generated molecules. A red dashed line on each plot indicates the threshold used for filtering.
    * **QED vs. SA Score Scatter Plot:** This plot visualizes the relationship between the two key properties for all molecules. The red lines define a "hit region" in the bottom-right quadrant, which contains the most desirable molecules (high QED, low SA).
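The hit criteria amount to four inclusive bounds combined with `and`. A minimal sketch with invented property values (the helper name `is_hit` is hypothetical):

```python
QED_MIN, SA_MAX, MW_MIN, MW_MAX = 0.60, 5.0, 150, 600

def is_hit(qed, sa, mw):
    """True if a molecule passes all the thresholds used to select hits (bounds inclusive)."""
    return qed >= QED_MIN and sa <= SA_MAX and MW_MIN <= mw <= MW_MAX

# Invented (QED, SA, MW) tuples, for illustration only
props = [(0.72, 2.9, 380.4), (0.55, 2.1, 250.0), (0.81, 5.6, 310.2), (0.66, 3.3, 640.7)]
selected = [p for p in props if is_hit(*p)]
print(f"{len(selected)}/{len(props)} hits ({len(selected) / len(props):.0%})")
```

The last tuple lies inside the QED/SA hit region of the scatter plot but is rejected by the MW bound. Because the scatter plot shows only QED and SA, some points in its hit region do not end up in `VAE_top_hits.csv`.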

**Outcome:**
All molecules that pass the filters are saved to a new file, `VAE_top_hits.csv`. The generated plots provide a visual summary of the model's output, showing that while it explores a broad chemical space, a specific subset of high-quality molecules can be isolated using these standard filters.

In [None]:
# ==============================================================
# 1) Filter "Hit" Molecules + 2) Plot Distributions and Scatter
#    (Assuming df is the DataFrame generated in the previous cell)
# ==============================================================

# ---------- 1. Select Hits ----------
# Suggested criteria; adjust as desired
QED_MIN   = 0.60          # Drug-likeness
SA_MAX    = 5.0           # Acceptable synthetic accessibility
MW_MIN    = 150           # Remove very light ones (often solvents or small fragments)
MW_MAX    = 600           # Somewhat looser than Lipinski's 500 Da limit

hits = df[
    (df.qed >= QED_MIN) &
    (df.sa  <= SA_MAX)  &
    (df.mw  >= MW_MIN)  &
    (df.mw  <= MW_MAX)
].copy()

print(f"🎯  Found {len(hits)} high-quality molecules out of {len(df)}  "
      f"({len(hits)/len(df):.1%})")

# Save Hits
hits.to_csv("VAE_top_hits.csv", index=False)
print("✅  Saved to VAE_top_hits.csv")


# ---------- 2. Plot Distributions ----------
plt.figure(figsize=(14,4))

# ---QED Histogram ---
plt.subplot(1,3,1)
plt.hist(df.qed, bins=40)
plt.axvline(QED_MIN, color='red', linestyle='--')
plt.title("QED distribution")
plt.xlabel("QED")
plt.ylabel("Count")

# ---SA Histogram---
plt.subplot(1,3,2)
plt.hist(df.sa, bins=40)
plt.axvline(SA_MAX, color='red', linestyle='--')
plt.title("SA Score distribution")
plt.xlabel("SA (lower = easier)")
plt.ylabel("Count")

# --- Scatter QED vs SA ---
plt.subplot(1,3,3)
plt.scatter(df.qed, df.sa, s=10, alpha=0.4)
plt.axvline(QED_MIN, color='red', linestyle='--')
plt.axhline(SA_MAX, color='red', linestyle='--')
plt.title("QED vs SA")
plt.xlabel("QED (higher is better)")
plt.ylabel("SA (lower is better)")

plt.tight_layout()
plt.show()

### 7. Reinforcement Learning (RL) for Targeted Molecular Optimization

**Objective:**
This cell performs the final and most crucial optimization step. It fine-tunes the previously trained VAE model using a Reinforcement Learning (RL) approach. The goal is to steer the model's generative process towards molecules that are not only drug-like but can also be embedded as a 3D conformer, a basic test of structural plausibility.

**Methodology:**
This stage implements a policy gradient-based RL loop that directly optimizes the VAE's decoder:

1.  **Encoder Freezing:** The weights of the VAE's encoder (the GRU encoder and the `mu`/`logvar` projections) are frozen, so only the decoder side of the model is updated during this phase. The embedding layer (`emb`) is shared by the encoder and decoder and is not frozen, so it is fine-tuned as well. The model keeps the latent space learned during pre-training while adapting its generation strategy.
2.  **Custom Reward Function:** A composite reward function scores each generated molecule and serves as the feedback signal for the RL agent. The reward `R` has three components:
    * **Chemical Properties (QED & SA Score):** The base score, `R = 2*QED - LAMBDA_SA*(SA/10)` with `LAMBDA_SA = 0.30`, favors molecules with high QED (drug-likeness) and low SA Score (ease of synthesis).
    * **3D Embeddability (Penalty):** The code attempts to generate a 3D conformer for each molecule using `AllChem.EmbedMolecule()` with ETKDG parameters. If embedding fails (returns `-1`), the molecule is treated as physically unrealistic and receives a fixed penalty of `R = -1.0` instead of the base score. This discourages the model from generating structures that cannot be built in 3D.
    * **High-Quality Hits (Bonus):** Molecules that meet the stricter "hit" criteria (`QED > 0.6` and `SA < 5.0`) receive an additional bonus of `+0.5`, further encouraging the model to explore the most promising regions of chemical space.
3.  **RL Training Loop:**
    * At each epoch, the model **samples** a batch of new molecules.
    * The **reward** for each valid molecule is calculated using the function described above.
    * The model's weights are updated with a REINFORCE-style policy gradient, one update per epoch. The loss increases the log-probability of molecules whose reward exceeds a moving-average baseline (`baseline`, decayed with `GAMMA = 0.9`) and decreases it for molecules below it. Subtracting the baseline reduces the variance of the gradient estimate and focuses the model on generating better-than-average compounds.
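The reward and the baseline-corrected policy-gradient loss can be traced by hand with plain floats. This sketch mirrors the reward rules and the `GAMMA`-decayed baseline of the RL cell; `reward` and `reinforce_loss` are hypothetical names, the property values and log-probabilities are invented, and in the notebook `logps` are tensors that carry gradients back into the decoder.

```python
def reward(qed, sa, embedded, lambda_sa=0.30):
    """-1 if 3D embedding failed; else 2*QED - lambda_sa*SA/10, plus 0.5 for hits."""
    if not embedded:
        return -1.0
    r = 2 * qed - lambda_sa * (sa / 10.0)
    if qed > 0.6 and sa < 5.0:
        r += 0.5
    return r

def reinforce_loss(rewards, logps, baseline, gamma=0.9):
    """Update the moving-average baseline, then return (loss, new_baseline).

    loss = -mean((R - b) * log p): minimizing it raises log p for samples
    whose reward beats the baseline and lowers it for the others.
    """
    mean_r = sum(rewards) / len(rewards)
    baseline = gamma * baseline + (1 - gamma) * mean_r
    loss = -sum((r - baseline) * lp for r, lp in zip(rewards, logps)) / len(rewards)
    return loss, baseline

# Invented properties: a hit, a non-hit, and a molecule that failed 3D embedding
rs = [reward(0.7, 3.0, True), reward(0.5, 4.0, True), reward(0.9, 2.0, False)]
loss, baseline = reinforce_loss(rs, [-20.0, -25.0, -30.0], baseline=0.0)
print([round(r, 2) for r in rs], round(baseline, 4), round(loss, 3))
```

The rewards are 1.81, 0.88, and -1.0; with the baseline starting at 0, only the embedding failure falls below the updated baseline, so it is the only sample whose log-probability the update pushes down.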

**Outcome:**
The training loop fine-tunes the decoder over `N_EPOCHS` policy-gradient updates, with the aim of raising the mean reward from epoch to epoch. The resulting weights are saved to `vae_selfies_rl.pt` in the next cell. This final model is optimized not just for general drug-likeness but also for producing novel molecules that can be embedded in 3D and are therefore more likely to be physically plausible.

In [None]:
# =========================================================
# RL fine-tune (QED ↑ , SA ↓) – Compatible with Current Tokenizer/Model
# =========================================================

N_EPOCHS   = 3
BATCH_SIZE = 256
MAX_LEN    = 150
LAMBDA_SA  = 0.30
LR, CLIP   = 1e-4, 1.0
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
GAMMA      = 0.9      # baseline decay

# ---------- Freeze Encoder ----------
for p in (*model.encoder.parameters(), *model.mu.parameters(), *model.logvar.parameters()):
    p.requires_grad_(False)

opt = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
baseline = 0.0

# ---------- Sample with Tokens ----------
def sample_with_tokens(model, tok, n, max_len, device):
    model.eval()
    with torch.no_grad():
        z   = torch.randn(n, model.mu.out_features, device=device)
        seq = torch.full((n, 1), tok.bos, dtype=torch.long, device=device)
        hidden = None
        for _ in range(max_len):
            dec_in  = torch.cat([model.emb(seq[:, -1]), z], -1).unsqueeze(1)
            dec_out, hidden = model.decoder(dec_in, hidden)
            logits = model.out_proj(dec_out.squeeze(1))
            logits[:, tok.bos] = -1e9
            nxt = torch.multinomial(torch.softmax(logits, -1), 1)
            seq = torch.cat([seq, nxt], 1)
            if (nxt == tok.eos).all():
                break
    selfies = ["".join(tok.idx2tok[i] for i in row[1:] if i not in (tok.pad, tok.eos)) for row in seq]
    return selfies, seq[:, 1:]         # SMILES will be converted later

model.to(DEVICE)

# ... (previous RL cell code) ...

for ep in range(1, N_EPOCHS + 1):
    sf_batch, toks_batch = sample_with_tokens(model, tok, BATCH_SIZE, MAX_LEN, DEVICE)

    model.train()
    rewards, logps = [], []
    valid_smiles_count = 0 # To count valid molecules

    for se, toks in zip(sf_batch, toks_batch):
        smi = ""
        try:
            smi = sf.decoder(se)
            mol = Chem.MolFromSmiles(smi)
            if mol is None: continue
        except sf.DecoderError:
            continue

        # --- Key modified section ---
        mol_h = Chem.AddHs(mol)
        # Attempt to build 3D structure
        # embed_result = AllChem.EmbedMolecule(mol_h, AllChem.ETKDG(), randomSeed=42)

        # 1. First, create a parameter object
        params = AllChem.ETKDG()
        # 2. Set the desired parameter (randomSeed) in it
        params.randomSeed = 42
        # 3. Now call the function with the modified parameter object
        embed_result = AllChem.EmbedMolecule(mol_h, params)

        R = 0.0 # Initial reward value
        if embed_result == -1:
            # 1. Penalty for molecules that cannot be built in 3D
            R = -1.0
        else:
            # 2. Calculate reward for valid molecules
            qed_score = QED.qed(mol)
            sa_score = sascorer.calculateScore(mol)
            R = 2 * qed_score - LAMBDA_SA * (sa_score / 10.0)

            # 3. (Optional) Bonus for excellent hits
            if qed_score > 0.6 and sa_score < 5.0:
                R += 0.5

        rewards.append(R)
        valid_smiles_count += 1

        # Log-probability of the sampled token sequence under the current policy
        inp = torch.cat([torch.tensor([tok.bos], device=DEVICE), toks[:-1].to(DEVICE)])
        tgt = toks.to(DEVICE)
        logits, _, _ = model(inp.unsqueeze(0), tgt.unsqueeze(0),
                             torch.tensor([len(tgt)], device=DEVICE),
                             tfr=0.0)
        lp = torch.log_softmax(logits, -1)
        seq_lp = lp.gather(2, tgt.unsqueeze(0).unsqueeze(-1)).squeeze().sum()
        logps.append(seq_lp)

    if not rewards:
        print(f"[RL-Ep {ep}] skip (0 valid molecules)")
        continue

    rewards = torch.tensor(rewards, device=DEVICE)
    logps   = torch.stack(logps)
    baseline = GAMMA * baseline + (1 - GAMMA) * rewards.mean().item()
    loss = -((rewards - baseline) * logps).mean()

    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
    opt.step()

    hits = (rewards > 0.5).float().mean()
    # Log mean reward, fraction of scored molecules with reward > 0.5, and valid-molecule count
    print(f"[RL-Ep {ep}] meanR={rewards.mean():.3f} | hits>0.5={hits:.2%} | valid {valid_smiles_count}/{BATCH_SIZE}")
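The reward shaping and policy-gradient arithmetic in the loop above can be isolated into plain functions. This makes the reward easy to unit-test without sampling from the model. The sketch mirrors the code above:

* a penalty of -1 when 3D embedding fails;
* `2*QED - LAMBDA_SA*SA/10` otherwise, plus a 0.5 bonus for strong hits;
* a moving-average baseline;
* the loss `-mean((R - b) * log p)`.

The function names are hypothetical. `LAMBDA_SA` and `GAMMA` are set in an earlier cell not shown here, so the values in the usage lines are placeholders.

```python
def shaped_reward(qed: float, sa: float, embedded: bool, lambda_sa: float) -> float:
    """Per-molecule reward mirroring the RL loop (hypothetical helper name)."""
    if not embedded:                 # 3D embedding (ETKDG) failed
        return -1.0
    r = 2.0 * qed - lambda_sa * (sa / 10.0)
    if qed > 0.6 and sa < 5.0:       # bonus for strong hits
        r += 0.5
    return r


def update_baseline(baseline: float, rewards: list, gamma: float) -> float:
    """Exponential moving average of the batch-mean reward."""
    return gamma * baseline + (1.0 - gamma) * (sum(rewards) / len(rewards))


def reinforce_loss(rewards: list, seq_logps: list, baseline: float) -> float:
    """-mean((R - b) * log p(sequence)); minimizing it raises log p of above-baseline samples."""
    terms = [(r - baseline) * lp for r, lp in zip(rewards, seq_logps)]
    return -sum(terms) / len(terms)


# Usage with placeholder hyperparameters (LAMBDA_SA = 1.0, GAMMA = 0.9 assumed)
batch = [shaped_reward(0.7, 3.0, True, 1.0),    # good hit: 1.4 - 0.3 + 0.5
         shaped_reward(0.4, 6.0, True, 1.0),    # no bonus: 0.8 - 0.6
         shaped_reward(0.5, 2.0, False, 1.0)]   # embedding failed: -1
b = update_baseline(0.0, batch, gamma=0.9)
loss = reinforce_loss(batch, [-20.0, -25.0, -30.0], b)
```

The loop updates the baseline with the current batch before computing the loss, so `update_baseline` is called first here as well. Subtracting the baseline leaves the expected gradient unchanged but reduces its variance.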

### 8. Final Model Checkpointing and Large-Scale Generation

**Objective:**
This cell serves two purposes. First, it saves the final, optimized state of the Reinforcement Learning model for future use and reproducibility. Second, it uses this model to generate a library of 10,000 new molecules, which forms the basis for the final analysis and hit selection.

**Methodology:**
1.  **Model Saving:** The state dictionary (`state_dict`) of the fine-tuned RL model is saved to `vae_selfies_rl.pt` with `torch.save`. The checkpoint also stores the vocabulary (`tok.idx2tok`) and the RL settings used (`N_EPOCHS`, `LAMBDA_SA`), so it captures the fully optimized generator.
2.  **Model Loading:** To confirm that the checkpoint restores correctly, a new instance of the VAE model (`model_rl_loaded`) is created and the saved weights are loaded into it. The model is then set to evaluation mode (`.eval()`) to disable dropout and other training-specific behaviors.
3.  **Large-Scale Sampling:** The `sample` function is called to generate `N_SAMPLES` (10,000) new molecules from the loaded RL-tuned model. The resulting library reflects the model's optimized generative policy.
4.  **Results Storage:** The generated molecules are processed and saved in two formats:
    * **SMILES File:** All non-empty generated SMILES strings are written to `rl_samples_10k.smi`.
    * **Properties CSV:** For each valid molecule, its QED, SA score, and molecular weight are calculated. These properties are compiled into a pandas DataFrame and saved to `rl_samples_props.csv`.

**Outcome:**
This cell produces the final outputs of the generative pipeline: a saved, reusable model (`vae_selfies_rl.pt`) and a comprehensive dataset (`rl_samples_props.csv`) of novel molecules and their key properties, ready for final visualization and analysis.
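The checkpoint stores only `tok.idx2tok`. A later session that no longer has `tok` in memory must therefore rebuild the token-to-index map from that list. Below is a minimal sketch with a hypothetical helper name. It assumes the vocabulary is a Python list whose positions are the token ids. It also assumes special tokens named `<pad>`, `<bos>` and `<eos>`; the notebook's tokenizer may name them differently.

```python
def rebuild_vocab(idx2tok, specials=("<pad>", "<bos>", "<eos>")):
    """Return a token -> id dict for a saved idx2tok list (list position = token id)."""
    tok2idx = {t: i for i, t in enumerate(idx2tok)}
    # Duplicate tokens would silently collapse in the dict and break the id mapping
    if len(tok2idx) != len(idx2tok):
        raise ValueError("saved vocabulary contains duplicate tokens")
    # Special-token names are an assumption; adjust them to the notebook's tokenizer
    missing = [s for s in specials if s not in tok2idx]
    if missing:
        raise KeyError(f"special tokens missing from vocabulary: {missing}")
    return tok2idx


# Usage with a toy vocabulary (a real one would come from checkpoint["vocab"])
idx2tok = ["<pad>", "<bos>", "<eos>", "[C]", "[O]", "[=C]"]
tok2idx = rebuild_vocab(idx2tok)
pad_id, bos_id, eos_id = (tok2idx[s] for s in ("<pad>", "<bos>", "<eos>"))
```

The rows of the embedding and output layers in the saved `state_dict` are indexed by these ids. The saved list order must therefore be kept exactly; sorting or deduplicating it would misalign tokens and weights.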

In [None]:
# --------- 1) Save RL-fine-tuned weights -----------
torch.save(
    {
        "state_dict": model.state_dict(),
        "vocab": tok.idx2tok,     # For rebuilding the Tokenizer
        "rl_epochs": N_EPOCHS,
        "lambda_sa": LAMBDA_SA
    },
    "vae_selfies_rl.pt"
)
print("✅ Optimized model saved → vae_selfies_rl.pt")


# --------- 2) Sample from RL-tuned model and save results ----------

SMI_OUT = "rl_samples_10k.smi"
CSV_OUT = "rl_samples_props.csv"
N_SAMPLES = 10000

# 1. Create a new instance of the model
#    Note: We use the 'tok' variable created in cell [11] because the vocabulary is the same
model_rl_loaded = VAE(len(tok.idx2tok), tok.pad, tok.bos).to(DEVICE)

# 2. Load the saved weights into the new model
checkpoint = torch.load("vae_selfies_rl.pt", map_location=DEVICE)
model_rl_loaded.load_state_dict(checkpoint["state_dict"])
model_rl_loaded.eval()  # Important: Set the model to evaluation mode

print("✅ Fine-tuned model successfully loaded.")

# 3. Sample using the new model (model_rl_loaded)
print(f"⏳ Sampling from RL model with {N_SAMPLES} samples...")
smiles_list = sample(model_rl_loaded, tok, n=N_SAMPLES, max_len=MAX_LEN, device=DEVICE)



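# SMILES file (one molecule per line; empty strings from failed decodes are skipped)
n_written = 0
with open(SMI_OUT, "w") as f:
    for smi in smiles_list:
        if smi:  # Only save non-empty strings
            f.write(smi + "\n")
            n_written += 1
print(f"✅ {n_written} of {len(smiles_list)} generated SMILES saved to → {SMI_OUT}")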

# CSV with QED / SA / MW
rows = []
valid_count = 0
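for smi in smiles_list:
    if not smi:  # Skip empty strings from failed decodes
        continue
    mol = Chem.MolFromSmiles(smi)
    # Skip unparsable strings and zero-atom molecules, which cannot be scored
    if mol is not None and mol.GetNumAtoms() > 0:
        valid_count += 1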
        rows.append(
            {
                "smiles": smi,
                "qed": QED.qed(mol),
                "sa": sascorer.calculateScore(mol),
                "mw": Descriptors.MolWt(mol),
            }
        )
pd.DataFrame(rows).to_csv(CSV_OUT, index=False)
print(f"✅ Properties for {valid_count} valid molecules saved to → {CSV_OUT}")
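Besides QED, SA and MW, a generated library is often summarized by three fractions:

* **validity:** the share of samples that decode to a molecule;
* **uniqueness:** the share of valid molecules that are distinct;
* **novelty:** the share of unique molecules absent from the training set.

The following minimal, standard-library-only sketch uses a hypothetical function name and is not part of the notebook. It assumes `generated` holds canonical SMILES, for example produced with RDKit's `Chem.MolToSmiles`, with `""` or `None` marking failed decodes. It also assumes `training` holds canonical SMILES of the training set. Validity here only means a non-empty entry; the RDKit parse check above plays that role in the notebook.

```python
def library_stats(generated, training):
    """Validity, uniqueness and novelty fractions of a generated SMILES library."""
    valid = [s for s in generated if s]      # drop "" / None (failed decodes)
    unique = set(valid)                      # assumes SMILES are already canonical
    novel = unique - set(training)           # molecules not seen during training
    return {
        "validity": len(valid) / len(generated) if generated else 0.0,
        "uniqueness": len(unique) / len(valid) if valid else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }


# Usage on a toy library: 3 of 5 valid, 2 unique, 1 of them novel
stats = library_stats(["CCO", "CCO", "c1ccccc1", "", None], training={"CCO"})
```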

### 9. Final Analysis: Hit Selection and Visualization of RL-Tuned Molecules

**Objective:**
This final cell analyzes the large library of molecules generated by the **Reinforcement Learning (RL) optimized model**. The goal is to identify the highest-quality "hit" compounds and to visualize the property distributions of the final generated chemical space. This provides a clear assessment of the success of the RL optimization strategy.

**Methodology:**
1.  **Data Loading:** The `rl_samples_props.csv` file is loaded into a pandas DataFrame. It contains the valid molecules from the 10,000 samples drawn from the final RL model, together with their QED, SA and MW values.
2.  **Hit Identification:** The same stringent filtering criteria used previously (QED ≥ 0.6, SA Score ≤ 5.0, MW between 150 and 600 Da) are applied to this new dataset to identify the top candidates.
3.  **Visualization:** A series of plots are generated to visualize the property distributions of the RL-tuned molecules:
    * **QED & SA Score Histograms:** These plots show the frequency distribution for drug-likeness and synthetic accessibility. The red dashed lines indicate the filtering thresholds. These plots are critical for comparing how the RL process has shifted the distribution of generated molecules towards more desirable regions compared to the pre-trained model.
    * **QED vs. SA Score Scatter Plot:** This plot provides a global view of the chemical property space, with the "hit region" clearly marked. The density and location of the points illustrate the model's success in focusing its generation on high-reward areas.

**Outcome:**
The top candidate molecules from the RL model are identified and saved to `RL_top_hits.csv`. The printed hit rate and these plots can be compared against those from the pre-trained model. This gives a quantitative and qualitative measure of the impact of RL fine-tuning and shows whether the model has learned to generate molecules with more favorable drug-like properties.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# ==============================================================
# 1) Filter "Hit" Molecules + 2) Plot Distributions and Scatter
#    (Reads the properties CSV written by the previous cell)
# ==============================================================
# ---------- 1. Select Hits ----------
# Suggested criteria; adjust as desired
QED_MIN   = 0.60          # Drug-likeness
SA_MAX    = 5.0           # Acceptable synthetic accessibility
MW_MIN    = 150           # Remove very light ones (often solvents or small fragments)
MW_MAX    = 600           # Relaxed cap (Lipinski's Rule of Five uses MW ≤ 500 Da)
df_rl = pd.read_csv("rl_samples_props.csv")
hits = df_rl[
    (df_rl.qed >= QED_MIN) &
    (df_rl.sa  <= SA_MAX)  &
    (df_rl.mw  >= MW_MIN)  &
    (df_rl.mw  <= MW_MAX)
].copy()

print(f"🎯  Found {len(hits)} high-quality molecules out of {len(df_rl)}  "
      f"({len(hits)/len(df_rl):.1%})")

# Save Hits
hits.to_csv("RL_top_hits.csv", index=False)
print("✅  Saved to RL_top_hits.csv")


# ---------- 2. Plot Distributions ----------
plt.figure(figsize=(14,4))

# --- QED Histogram ---
plt.subplot(1,3,1)
plt.hist(df_rl.qed, bins=40)
plt.axvline(QED_MIN, color='red', linestyle='--')
plt.title("QED distribution")
plt.xlabel("QED")
plt.ylabel("Count")

# --- SA Histogram ---
plt.subplot(1,3,2)
plt.hist(df_rl.sa, bins=40)
plt.axvline(SA_MAX, color='red', linestyle='--')
plt.title("SA Score distribution")
plt.xlabel("SA (lower = easier)")
plt.ylabel("Count")

# --- Scatter QED vs SA ---
plt.subplot(1,3,3)
plt.scatter(df_rl.qed, df_rl.sa, s=10, alpha=0.4)
plt.axvline(QED_MIN, color='red', linestyle='--')
plt.axhline(SA_MAX, color='red', linestyle='--')
plt.title("QED vs SA")
plt.xlabel("QED (higher is better)")
plt.ylabel("SA (lower is better)")

plt.tight_layout()
plt.show()
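The section above compares the RL library's property distributions with those of the pre-trained model. A small numeric summary can make that comparison explicit. This sketch uses only the standard library and reads property CSVs with the `smiles,qed,sa,mw` columns written in Section 8. It reports mean QED, median SA and hit rate with the same thresholds as the cell above. The pre-trained model's CSV is not named in this section, so the path in the usage comment is hypothetical.

```python
import csv
from statistics import mean, median

QED_MIN, SA_MAX, MW_MIN, MW_MAX = 0.60, 5.0, 150, 600   # same thresholds as above


def load_props(lines):
    """Parse CSV lines (file object or list of strings) with smiles,qed,sa,mw columns."""
    return [
        {"smiles": r["smiles"], "qed": float(r["qed"]),
         "sa": float(r["sa"]), "mw": float(r["mw"])}
        for r in csv.DictReader(lines)
    ]


def is_hit(row):
    return row["qed"] >= QED_MIN and row["sa"] <= SA_MAX and MW_MIN <= row["mw"] <= MW_MAX


def summarize(rows):
    return {
        "n": len(rows),
        "mean_qed": mean(r["qed"] for r in rows),
        "median_sa": median(r["sa"] for r in rows),
        "hit_rate": sum(is_hit(r) for r in rows) / len(rows),
    }


def compare(before, after):
    """Change from the pre-trained library (before) to the RL library (after)."""
    b, a = summarize(before), summarize(after)
    return {k: a[k] - b[k] for k in ("mean_qed", "median_sa", "hit_rate")}


# Usage ("pretrained_props.csv" is a hypothetical file name):
# with open("pretrained_props.csv") as f1, open("rl_samples_props.csv") as f2:
#     print(compare(load_props(f1), load_props(f2)))
```

If RL fine-tuning worked as intended, `mean_qed` and `hit_rate` should rise and `median_sa` should fall. The median is used for SA because its distribution is often skewed by a tail of hard-to-synthesize molecules.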


### 10. Conclusion and Future Work

### Conclusion

This project successfully implemented an end-to-end pipeline for de novo molecular design using a deep generative model. The primary goal was to generate novel, drug-like molecules targeted at the EGFR kinase domain.

The key achievements of this work include:
* **Successful Data Pipeline:** Bioactivity data for EGFR was successfully retrieved from the ChEMBL database and preprocessed for model training.
* **Valid Molecule Generation:** By employing a Variational Autoencoder (VAE) with the **SELFIES** molecular representation, we avoided the common problem of invalid SMILES generation. Every SELFIES string decodes to a syntactically valid molecule, so structural validity is guaranteed by construction.
* **Quality Improvement with Curriculum Learning:** The base VAE was effectively fine-tuned on a high-quality subset of drug-like molecules, successfully biasing the generator towards more favorable chemical space.
* **Targeted Optimization with Reinforcement Learning:** The final model was optimized using an RL agent with a custom reward function. The results show that this process shifted the distribution of generated molecules towards higher **QED** (drug-likeness) and lower **SA Scores** (easier synthesis).

Ultimately, the project demonstrates the successful application of modern deep learning techniques to create a specialized model capable of generating novel and optimized molecules for a specific therapeutic target.

### Future Work

While this project established a robust generative pipeline, several exciting avenues remain for future exploration to bring these computational designs closer to real-world drug candidates:

1.  **More Extensive Training:** Re-run the Reinforcement Learning fine-tuning on more powerful hardware for a greater number of epochs. This would allow the model to better converge and likely lead to the generation of molecules with even higher reward scores.
2.  **Molecular Docking Simulation:** The immediate next step is to perform molecular docking to predict the binding affinity of the top-generated hits against the EGFR protein's active site (e.g., using PDB ID `1M17`). This would provide a direct, structure-based estimate of their inhibitory potential.
3.  **RL Optimization with Docking Score:** A more advanced approach would be to integrate the docking score directly into the Reinforcement Learning reward function. This would create a powerful feedback loop, training the model to generate molecules that are not only drug-like but also have a predicted high binding affinity for the target.
4.  **Molecular Dynamics (MD) Simulations:** For the most promising candidates identified through docking, MD simulations could be performed to assess the stability of the ligand-protein complex over time, providing deeper insights into the binding interactions.
5.  **Synthesis and In Vitro Assay:** The ultimate validation would be to synthesize the top computationally-ranked molecules in a laboratory and test their actual inhibitory activity against the EGFR protein through in vitro biological assays.
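The **RL Optimization with Docking Score** item above proposes folding a docking score into the reward. One hypothetical way to do this:

1. Clip the score to a fixed window.
2. Rescale it to [0, 1].
3. Add it, weighted, to the existing QED/SA term.

The sketch assumes scores in kcal/mol, where more negative means stronger predicted binding. This is the convention of AutoDock Vina's affinity output. The window of -12 to -4 kcal/mol and the weight `w_dock` are illustrative assumptions, not tuned values. The embedding penalty and hit bonus from the RL loop are omitted for brevity.

```python
DOCK_BEST, DOCK_WORST = -12.0, -4.0   # hypothetical normalization window (kcal/mol)
W_DOCK = 1.0                          # hypothetical weight of the docking term


def docking_term(score: float) -> float:
    """Map a docking score to [0, 1]: DOCK_WORST or weaker -> 0, DOCK_BEST or stronger -> 1."""
    clipped = min(max(score, DOCK_BEST), DOCK_WORST)
    return (DOCK_WORST - clipped) / (DOCK_WORST - DOCK_BEST)


def reward_with_docking(qed: float, sa: float, docking_score: float,
                        lambda_sa: float, w_dock: float = W_DOCK) -> float:
    """QED/SA reward from the RL loop plus a weighted, normalized docking term."""
    base = 2.0 * qed - lambda_sa * (sa / 10.0)
    return base + w_dock * docking_term(docking_score)


# Usage: a -8 kcal/mol pose sits halfway through the window, contributing 0.5 * W_DOCK
r = reward_with_docking(0.7, 3.0, -8.0, lambda_sa=1.0)
```

Clipping keeps a few extreme (often artifactual) docking scores from dominating the policy gradient. Because docking is far slower than computing QED or SA, such a reward would likely need caching or a surrogate model in practice.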