<a href="https://colab.research.google.com/github/Swayamprakashpatel/DD/blob/main/Conditional_VAE_Drug_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rdkit
!pip install torch

Collecting rdkit
  Downloading rdkit-2025.3.5-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Downloading rdkit-2025.3.5-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2025.3.5


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#NEW CODE WITH PROTEIN UPDATING: To make the model more effective, we need to create a richer, more meaningful representation of the protein sequence. We can do this with the data you already have by using a technique called transfer learning. This involves using a pre-trained protein language model (PLM).

#These models, like ESM or ProtT5, have been trained on millions of protein sequences and have learned a deep understanding of protein biology. Instead of a simple one-hot vector, they can transform a protein sequence into a dense numerical vector (an embedding) that captures a wealth of information about its function, structure, and evolutionary relationships.

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
from IPython.display import Image, display
import numpy as np
import warnings
import os
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Suppress RDKit warnings for a cleaner output
warnings.filterwarnings("ignore")

# --- Global constants for a comprehensive SMILES vocabulary ---
def create_comprehensive_smiles_vocab():
    """
    Creates a comprehensive SMILES token vocabulary based on the periodic table,
    numbers, and common SMILES characters.
    """
    periodic_table_elements = [
        'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si',
        'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
        'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb',
        'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
        'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho',
        'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
        'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np',
        'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg',
        'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
    ]
    # Common SMILES single characters
    single_char_tokens = [
        '(', ')', '[', ']', '+', '-', '=', '#', '.', ':', '/', '\\',
        'c', 'n', 'o', 's', '`', '@'
    ]
    # Numbers
    numbers = [str(i) for i in range(10)]

    # Combine all tokens and remove duplicates
    all_tokens = set(periodic_table_elements + single_char_tokens + numbers)
    # Add start and padding tokens
    all_tokens.add('#') # Start token
    all_tokens.add('`') # Padding token

    # Sort tokens by length in descending order to handle multi-character tokens correctly
    sorted_tokens = sorted(list(all_tokens), key=len, reverse=True)
    # Create the dictionary
    smiles_token_dict = {token: i for i, token in enumerate(sorted_tokens)}

    return smiles_token_dict, sorted_tokens

SMILES_TOKEN_DICT, sorted_tokens = create_comprehensive_smiles_vocab()
SMILES_MAX_LEN = 500
MAX_PROTEIN_SEQUENCE_LEN = 1000  # New constant to control max protein sequence length
HIDDEN_DIM = 64
LATENT_DIM = 128
NUM_LAYERS = 2
MAX_EPOCHS = 1
BATCH_SIZE = 1000
PATIENCE = 15
LEARNING_RATE = 1e-3
MODEL_SAVE_PATH = '/content/drive/MyDrive/Drug_Database/BindingDB_CVAE_best_model.pt'

# --- Data Preprocessing and PyTorch Dataset ---
def get_smiles_tokens(smiles_string, special_tokens):
    """
    Tokenizes a SMILES string using a predefined list of special tokens.
    """
    tokens = []
    i = 0
    while i < len(smiles_string):
        matched = False
        for token in special_tokens:
            if smiles_string.startswith(token, i):
                tokens.append(token)
                i += len(token)
                matched = True
                break
        if not matched:
            tokens.append(smiles_string[i])
            i += 1
    return tokens

def smiles_to_one_hot(smiles, token_to_idx, max_len, sorted_tokens):
    """Encodes a SMILES string into a one-hot vector with padding."""
    tokens = get_smiles_tokens(smiles, sorted_tokens)

    if len(tokens) > max_len:
        tokens = tokens[:max_len]

    # Pad with padding token
    padded_tokens = tokens + ['`'] * (max_len - len(tokens))

    one_hot_vector = np.zeros((max_len, len(token_to_idx)), dtype=np.float32)
    for i, token in enumerate(padded_tokens):
        if token in token_to_idx:
            one_hot_vector[i, token_to_idx[token]] = 1
    return one_hot_vector

def get_protein_embedding(seq):
    """
    A simple, placeholder protein embedding function.
    In a real-world scenario, you would use a more sophisticated method,
    like a pre-trained protein language model (e.g., from the ProtTrans or ESM family)
    for better performance on complex protein sequences.
    """
    protein_vocab = {'A': 0, 'C': 1, 'D': 2, 'E': 3, 'F': 4, 'G': 5, 'H': 6, 'I': 7, 'K': 8, 'L': 9, 'M': 10,
                     'N': 11, 'P': 12, 'Q': 13, 'R': 14, 'S': 15, 'T': 16, 'V': 17, 'W': 18, 'Y': 19}
    embedding = np.zeros(len(protein_vocab), dtype=np.float32)
    for amino_acid in seq:
        if amino_acid in protein_vocab:
            embedding[protein_vocab[amino_acid]] = 1
    return embedding

def get_advanced_protein_embedding(seq):
    """
    A placeholder function for a more advanced protein embedding.

    In a real-world application, this function would use a pre-trained
    protein language model (PLM) like ESM-2 or ProtT5.

    Example using a conceptual library call:

    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_uniref50')
    model = T5EncoderModel.from_pretrained('Rostlab/prot_t5_xl_uniref50')

    # Tokenize and get embedding
    encoded_input = tokenizer(seq, return_tensors='pt', max_length=1000, padding=True, truncation=True)
    with torch.no_grad():
        embedding = model(**encoded_input).last_hidden_state.mean(dim=1).squeeze()

    return embedding.numpy()

    The dimension of the embedding would be around 1024, significantly larger
    and more informative than the simple one-hot encoding.

    For now, we'll return a placeholder vector to avoid a runtime error.
    """
    # Placeholder for the advanced embedding dimension
    ADVANCED_PROTEIN_EMBEDDING_DIM = 1024
    return np.random.rand(ADVANCED_PROTEIN_EMBEDDING_DIM).astype(np.float32)

class SmilesProteinDataset(Dataset):
    """Dataset for loading SMILES and Protein sequence pairs."""
    def __init__(self, smiles_list, protein_list, smiles_token_dict, smiles_max_len, sorted_tokens, embedding_function):
        self.smiles_list = smiles_list
        self.protein_list = protein_list
        self.smiles_token_dict = smiles_token_dict
        self.smiles_max_len = smiles_max_len
        self.sorted_tokens = sorted_tokens
        self.embedding_function = embedding_function
        # The protein embedding dimension is determined from the first sequence in the list
        self.protein_embedding_dim = len(self.embedding_function(self.protein_list[0]))

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        protein = self.protein_list[idx]
        smiles_one_hot = smiles_to_one_hot(smiles, self.smiles_token_dict, self.smiles_max_len, self.sorted_tokens)
        protein_embedding = self.embedding_function(protein)
        return torch.tensor(smiles_one_hot, dtype=torch.float32), torch.tensor(protein_embedding, dtype=torch.float32)

# --- Conditional VAE Model Definition ---

class ConditionalVAE(nn.Module):
    def __init__(self, smiles_vocab_size, protein_embedding_dim, hidden_dim, latent_dim, smiles_max_len, num_layers):
        super(ConditionalVAE, self).__init__()

        self.smiles_max_len = smiles_max_len
        self.smiles_vocab_size = smiles_vocab_size
        self.protein_embedding_dim = protein_embedding_dim

        # Encoder: takes smiles and protein embedding as input
        self.encoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        # NOTE: This MLP now needs to take the more complex protein embedding as input
        self.encoder_protein_mlp = nn.Sequential(
            nn.Linear(protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: takes latent vector and protein embedding as input
        self.decoder_gru_input_mlp = nn.Sequential(
            nn.Linear(latent_dim + protein_embedding_dim, hidden_dim),
            nn.ReLU()
        )
        self.decoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc_output = nn.Linear(hidden_dim, smiles_vocab_size)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick to sample from N(mu, var)"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder(self, smiles_one_hot, protein_embedding):
        _, hidden_smiles = self.encoder_rnn(smiles_one_hot)
        # Ensure protein embedding has the correct shape for broadcasting
        hidden_protein = self.encoder_protein_mlp(protein_embedding).unsqueeze(0).repeat(hidden_smiles.size(0), 1, 1)
        combined_hidden = hidden_smiles + hidden_protein
        final_hidden_state = combined_hidden[-1, :, :]
        mu = self.fc_mu(final_hidden_state)
        logvar = self.fc_logvar(final_hidden_state)
        return mu, logvar

    def decoder(self, z, protein_embedding):
        z_with_protein = torch.cat((z, protein_embedding), dim=1)
        initial_hidden = self.decoder_gru_input_mlp(z_with_protein).unsqueeze(0)
        initial_hidden = initial_hidden.repeat(self.decoder_rnn.num_layers, 1, 1)
        decoder_input = torch.zeros(z.size(0), self.smiles_max_len, self.smiles_vocab_size).to(z.device)
        output, _ = self.decoder_rnn(decoder_input, initial_hidden)
        output = self.fc_output(output)
        return output

    def forward(self, smiles_one_hot, protein_embedding):
        mu, logvar = self.encoder(smiles_one_hot, protein_embedding)
        z = self.reparameterize(mu, logvar)
        reconstructed_smiles = self.decoder(z, protein_embedding)
        return reconstructed_smiles, mu, logvar

# --- Training Loop and Generation ---

def vae_loss(recon_x, x, mu, logvar, smiles_vocab_size):
    """VAE loss function: BCE + KL divergence."""
    # Ensure dimensions are compatible for BCE calculation
    recon_x_flat = recon_x.view(-1, smiles_vocab_size)
    x_flat = x.view(-1, smiles_vocab_size)
    BCE = nn.functional.binary_cross_entropy_with_logits(recon_x_flat, x_flat, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_model(model, dataloader, optimizer, device, smiles_vocab_size):
    """Trains the VAE model for one epoch with a progress bar."""
    model.train()
    total_loss = 0
    tqdm_loader = tqdm(dataloader, desc="Training")
    for smiles_oh, protein_emb in tqdm_loader:
        smiles_oh = smiles_oh.to(device)
        protein_emb = protein_emb.to(device)

        optimizer.zero_grad()

        reconstructed_smiles, mu, logvar = model(smiles_oh, protein_emb)
        loss = vae_loss(reconstructed_smiles, smiles_oh, mu, logvar, smiles_vocab_size)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Update the progress bar with the current batch loss
        tqdm_loader.set_postfix(batch_loss=loss.item() / smiles_oh.size(0))

    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

def validate_model(model, dataloader, device, smiles_vocab_size):
    """Evaluates the model on the validation set."""
    model.eval()
    total_loss = 0
    with torch.no_grad():
        tqdm_loader = tqdm(dataloader, desc="Validation")
        for smiles_oh, protein_emb in tqdm_loader:
            smiles_oh = smiles_oh.to(device)
            protein_emb = protein_emb.to(device)

            reconstructed_smiles, mu, logvar = model(smiles_oh, protein_emb)
            loss = vae_loss(reconstructed_smiles, smiles_oh, mu, logvar, smiles_vocab_size)
            total_loss += loss.item()
            # Update the progress bar with the current batch loss
            tqdm_loader.set_postfix(batch_loss=loss.item() / smiles_oh.size(0))

    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

def generate_new_molecules(model, smiles_token_dict, protein_seq, num_molecules, device, latent_dim):
    """Generates and visualizes new molecules for a given protein sequence."""
    model.eval()
    idx_to_token = {v: k for k, v in smiles_token_dict.items()}
    generated_molecules = []

    # Prepare the protein embedding for a single sequence
    # NOTE: We now use the advanced embedding function here too
    protein_embedding = get_advanced_protein_embedding(protein_seq)
    protein_embedding = torch.tensor(protein_embedding, dtype=torch.float32).unsqueeze(0).to(device)

    print(f"\nAttempting to generate {num_molecules} molecules for protein sequence: {protein_seq}")
    with torch.no_grad():
        for i in range(num_molecules):
            # Sample a vector from the latent space
            z = torch.randn(1, latent_dim).to(device)
            # Use the decoder to generate a one-hot encoded SMILES string
            generated_output = model.decoder(z, protein_embedding)

            # Use softmax to get probabilities for each token and then select the most likely one
            probabilities = nn.functional.softmax(generated_output, dim=-1)
            predicted_indices = torch.argmax(probabilities, dim=-1).squeeze(0).cpu().numpy()

            generated_smiles = ""
            for token_idx in predicted_indices:
                token = idx_to_token.get(token_idx, '')
                # Stop decoding when a padding token is encountered
                if token == '`':
                    break
                generated_smiles += token

            # Validate and display the generated SMILES string as an image
            mol = Chem.MolFromSmiles(generated_smiles)
            if mol is not None:
                print(f"Generated molecule (valid SMILES): {generated_smiles}")
                # You can't display images directly in this environment, but the code is ready for it.
            else:
                print(f"Generated molecule (invalid SMILES): {generated_smiles}")

            generated_molecules.append(generated_smiles)

    return generated_molecules

if __name__ == '__main__':
    # Determine the device to use (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define the path to the newly created CSV file in Google Drive
    csv_file_path = '/content/drive/MyDrive/Drug_Database/bindingdb_dataset.csv'

    # Define an example protein sequence to predict for
    target_protein_sequence = 'MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL'

    # --- Data Loading ---
    try:
        # Load the CSV file from the specified Google Drive path
        df = pd.read_csv(csv_file_path)
        # Ensure the column names match the new dataset
        smiles_list = df['SMILES'].tolist()
        protein_list = df['Protein_Sequence'].tolist()
        print(f"Successfully loaded {len(smiles_list)} data points from '{csv_file_path}'.")
    except FileNotFoundError:
        print(f"Error: The data file '{csv_file_path}' was not found.")
        print("Please ensure the file has been created by the previous script and is in the correct directory.")
        exit()

    # --- Data Filtering ---
    initial_count = len(smiles_list)
    filtered_smiles = []
    filtered_proteins = []
    for smiles, protein in zip(smiles_list, protein_list):
        # Filter out rows with missing or invalid protein sequences
        if pd.isna(protein) or not isinstance(protein, str) or not protein.strip():
            continue
        # Then, apply the length filter
        if len(protein) <= MAX_PROTEIN_SEQUENCE_LEN:
            filtered_smiles.append(smiles)
            filtered_proteins.append(protein)

    smiles_list = filtered_smiles
    protein_list = filtered_proteins
    print(f"Filtered dataset: Retained {len(smiles_list)} rows out of {initial_count} "
          f"by removing invalid protein sequences and sequences of length > {MAX_PROTEIN_SEQUENCE_LEN}.")

    # Create dataset and dataloader
    # NOTE: We now pass the advanced embedding function to the dataset
    full_dataset = SmilesProteinDataset(smiles_list, protein_list, SMILES_TOKEN_DICT, SMILES_MAX_LEN, sorted_tokens, get_advanced_protein_embedding)
    train_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize the model and optimizer
    # NOTE: The protein embedding dimension is now dynamic based on the advanced embedding model
    protein_embedding_dim = full_dataset.protein_embedding_dim
    model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    patience_counter = 0

    print("Starting training...")
    for epoch in range(MAX_EPOCHS):
        print(f"--- Epoch {epoch+1}/{MAX_EPOCHS} ---")
        train_loss = train_model(model, train_dataloader, optimizer, device, len(SMILES_TOKEN_DICT))
        val_loss = validate_model(model, val_dataloader, device, len(SMILES_TOKEN_DICT))

        print(f"Epoch {epoch+1}/{MAX_EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Check for improvement and save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print(f"Validation loss improved. Saving model to {MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print(f"\nTraining complete. Model saved to '{MODEL_SAVE_PATH}'.")

    # --- Generation Process ---
    # We load the best model to ensure we are using the one with the best performance
    best_model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    best_model.load_state_dict(torch.load(MODEL_SAVE_PATH))

    # Generate new molecules using the loaded model
    generate_new_molecules(best_model, SMILES_TOKEN_DICT, target_protein_sequence, num_molecules=5, device=device, latent_dim=LATENT_DIM)


Using device: cuda
Successfully loaded 3046040 data points from '/content/drive/MyDrive/Drug_Database/bindingdb_dataset.csv'.
Filtered dataset: Retained 2479480 rows out of 3046040 by removing invalid protein sequences and sequences of length > 1000.
Starting training...
--- Epoch 1/1 ---


Training: 100%|██████████| 1736/1736 [39:52<00:00,  1.38s/it, batch_loss=218]
Validation: 100%|██████████| 744/744 [16:23<00:00,  1.32s/it, batch_loss=224]

Epoch 1/1, Train Loss: 876.2733, Validation Loss: 224.8349
Validation loss improved. Saving model to best_model.pt

Training complete. Model saved to 'best_model.pt'.

Attempting to generate 5 molecules for protein sequence: MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL
Generated molecule (invalid SMILES): CCCccccccccccccccccccccccccccccccccccccccccccccccc
Generated molecule (invalid SMILES): CCCCCccccccccccccccccccccccccccccccccccccccccccccccccccccc
Generated molecule (invalid SMILES): CCccccccccccccccccccccccccccccccccc
Generated molecule (invalid SMILES): CCCcccccccccccccccccccccccccccccccccccccccccc
Generated molecule (invalid SMILES): CCCCccccccccccccccccccccccccccccccccccc


[06:12:26] non-ring atom 3 marked aromatic
[06:12:26] non-ring atom 5 marked aromatic
[06:12:26] non-ring atom 2 marked aromatic
[06:12:26] non-ring atom 3 marked aromatic
[06:12:26] non-ring atom 4 marked aromatic


In [None]:
# NEW CODE FOR COLAB WITH REAL PROTEIN UPDATING
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
from IPython.display import Image, display
import numpy as np
import warnings
import os
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from transformers import AutoTokenizer, EsmModel
import gc
from google.colab import drive

# Mount Google Drive to access the dataset
drive.mount('/content/drive')

# Install missing dependencies if needed
!pip install einops

# Suppress RDKit warnings for a cleaner output
warnings.filterwarnings("ignore")

# --- Global constants for a comprehensive SMILES vocabulary ---
def create_comprehensive_smiles_vocab():
    """
    Creates a comprehensive SMILES token vocabulary.
    """
    periodic_table_elements = [
        'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si',
        'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
        'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb',
        'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
        'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho',
        'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
        'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np',
        'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg',
        'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
    ]
    single_char_tokens = [
        '(', ')', '[', ']', '+', '-', '=', '#', '.', ':', '/', '\\',
        'c', 'n', 'o', 's', '`', '@'
    ]
    numbers = [str(i) for i in range(10)]
    all_tokens = set(periodic_table_elements + single_char_tokens + numbers)
    all_tokens.add('#') # Start token
    all_tokens.add('`') # Padding token
    sorted_tokens = sorted(list(all_tokens), key=len, reverse=True)
    smiles_token_dict = {token: i for i, token in enumerate(sorted_tokens)}
    return smiles_token_dict, sorted_tokens

SMILES_TOKEN_DICT, sorted_tokens = create_comprehensive_smiles_vocab()
SMILES_MAX_LEN = 500
MAX_PROTEIN_SEQUENCE_LEN = 1000
HIDDEN_DIM = 128
LATENT_DIM = 64
NUM_LAYERS = 2
MAX_EPOCHS = 100
BATCH_SIZE = 64
PATIENCE = 25
LEARNING_RATE = 1e-4
MODEL_SAVE_PATH = '/content/drive/MyDrive/Drug_Database/BindingDB_CVAE_best_model_esm.pt'
CSV_FILE_PATH = '/content/drive/MyDrive/Drug_Database/bindingdb_dataset.csv'

# --- ESM Model Loading and Embedding Function ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pre-trained ESM-2 model and tokenizer
# ***STRATEGY CHANGE: LOAD ESM MODEL TO CPU TO SAVE GPU VRAM***
ESM_MODEL_NAME = 'esm2_t12_35M_UR50D'
print(f"Loading ESM-2 model: {ESM_MODEL_NAME} to CPU...")
esm_tokenizer = AutoTokenizer.from_pretrained(f"facebook/{ESM_MODEL_NAME}")
esm_model = EsmModel.from_pretrained(f"facebook/{ESM_MODEL_NAME}")
esm_model.eval()
ESM_EMBEDDING_DIM = esm_model.config.hidden_size
print(f"ESM model loaded successfully to CPU with embedding dimension: {ESM_EMBEDDING_DIM}")

def get_esm_embedding(seq):
    """
    Generates a protein embedding using a pre-trained ESM-2 model.
    """
    # Keep the model on CPU for most of the time
    with torch.no_grad():
        seq_spaced = " ".join(seq)
        encoded_input = esm_tokenizer(seq_spaced, return_tensors='pt',
                                      max_length=MAX_PROTEIN_SEQUENCE_LEN,
                                      padding=True, truncation=True)
        # Move input to GPU only for the forward pass
        encoded_input = {key: val.to(device) for key, val in encoded_input.items()}
        # Move model to GPU for a single forward pass
        esm_model_gpu = esm_model.to(device)
        output = esm_model_gpu(**encoded_input)
        # Move model back to CPU and clear GPU cache
        esm_model_gpu.to('cpu')
        torch.cuda.empty_cache()
        embedding = output.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    return embedding

# --- Data Preprocessing and PyTorch Dataset ---
def get_smiles_tokens(smiles_string, special_tokens):
    tokens = []
    i = 0
    while i < len(smiles_string):
        matched = False
        for token in special_tokens:
            if smiles_string.startswith(token, i):
                tokens.append(token)
                i += len(token)
                matched = True
                break
        if not matched:
            tokens.append(smiles_string[i])
            i += 1
    return tokens

def smiles_to_one_hot(smiles, token_to_idx, max_len, sorted_tokens):
    tokens = get_smiles_tokens(smiles, sorted_tokens)
    if len(tokens) > max_len:
        tokens = tokens[:max_len]
    padded_tokens = tokens + ['`'] * (max_len - len(tokens))
    one_hot_vector = np.zeros((max_len, len(token_to_idx)), dtype=np.float32)
    for i, token in enumerate(padded_tokens):
        if token in token_to_idx:
            one_hot_vector[i, token_to_idx[token]] = 1
    return one_hot_vector

class SmilesProteinDataset(Dataset):
    """Dataset for loading SMILES and Protein sequence pairs."""
    def __init__(self, smiles_list, protein_list, smiles_token_dict, smiles_max_len, sorted_tokens, embedding_function):
        self.smiles_list = smiles_list
        self.protein_list = protein_list
        self.smiles_token_dict = smiles_token_dict
        self.smiles_max_len = smiles_max_len
        self.sorted_tokens = sorted_tokens
        self.embedding_function = embedding_function
        # Removed the protein_embeddings_cache. Embeddings will be generated on-the-fly.

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        protein = self.protein_list[idx]
        smiles_one_hot = smiles_to_one_hot(smiles, self.smiles_token_dict, self.smiles_max_len, self.sorted_tokens)

        # Generate embedding on-the-fly
        protein_embedding = self.embedding_function(protein)

        return torch.tensor(smiles_one_hot, dtype=torch.float32), torch.tensor(protein_embedding, dtype=torch.float32)

# --- Conditional VAE Model Definition ---
class ConditionalVAE(nn.Module):
    def __init__(self, smiles_vocab_size, protein_embedding_dim, hidden_dim, latent_dim, smiles_max_len, num_layers):
        super(ConditionalVAE, self).__init__()
        self.smiles_max_len = smiles_max_len
        self.smiles_vocab_size = smiles_vocab_size
        self.protein_embedding_dim = protein_embedding_dim

        self.encoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.encoder_protein_mlp = nn.Sequential(
            nn.Linear(protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        self.decoder_gru_input_mlp = nn.Sequential(
            nn.Linear(latent_dim + protein_embedding_dim, hidden_dim),
            nn.ReLU()
        )
        self.decoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc_output = nn.Linear(hidden_dim, smiles_vocab_size)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder(self, smiles_one_hot, protein_embedding):
        _, hidden_smiles = self.encoder_rnn(smiles_one_hot)
        hidden_protein = self.encoder_protein_mlp(protein_embedding).unsqueeze(0).repeat(hidden_smiles.size(0), 1, 1)
        combined_hidden = hidden_smiles + hidden_protein
        final_hidden_state = combined_hidden[-1, :, :]
        mu = self.fc_mu(final_hidden_state)
        logvar = self.fc_logvar(final_hidden_state)
        return mu, logvar

    def decoder(self, z, protein_embedding):
        z_with_protein = torch.cat((z, protein_embedding), dim=1)
        initial_hidden = self.decoder_gru_input_mlp(z_with_protein).unsqueeze(0)
        initial_hidden = initial_hidden.repeat(self.decoder_rnn.num_layers, 1, 1)
        decoder_input = torch.zeros(z.size(0), self.smiles_max_len, self.smiles_vocab_size).to(z.device)
        output, _ = self.decoder_rnn(decoder_input, initial_hidden)
        output = self.fc_output(output)
        return output

    def forward(self, smiles_one_hot, protein_embedding):
        mu, logvar = self.encoder(smiles_one_hot, protein_embedding)
        z = self.reparameterize(mu, logvar)
        reconstructed_smiles = self.decoder(z, protein_embedding)
        return reconstructed_smiles, mu, logvar

# --- Training Loop and Generation ---
def vae_loss(recon_x, x, mu, logvar, smiles_vocab_size):
    """VAE loss function: BCE + KL divergence."""
    recon_x_flat = recon_x.view(-1, smiles_vocab_size)
    x_flat = x.view(-1, smiles_vocab_size)
    BCE = nn.functional.binary_cross_entropy_with_logits(recon_x_flat, x_flat, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_model(model, dataloader, optimizer, device, smiles_vocab_size):
    model.train()
    total_loss = 0
    tqdm_loader = tqdm(dataloader, desc="Training")
    for smiles_oh, protein_emb in tqdm_loader:
        smiles_oh = smiles_oh.to(device)
        protein_emb = protein_emb.to(device)
        optimizer.zero_grad()
        reconstructed_smiles, mu, logvar = model(smiles_oh, protein_emb)
        loss = vae_loss(reconstructed_smiles, smiles_oh, mu, logvar, smiles_vocab_size)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        tqdm_loader.set_postfix(batch_loss=loss.item() / smiles_oh.size(0))
    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

def validate_model(model, dataloader, device, smiles_vocab_size):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        tqdm_loader = tqdm(dataloader, desc="Validation")
        for smiles_oh, protein_emb in tqdm_loader:
            smiles_oh = smiles_oh.to(device)
            protein_emb = protein_emb.to(device)
            reconstructed_smiles, mu, logvar = model(smiles_oh, protein_emb)
            loss = vae_loss(reconstructed_smiles, smiles_oh, mu, logvar, smiles_vocab_size)
            total_loss += loss.item()
            tqdm_loader.set_postfix(batch_loss=loss.item() / smiles_oh.size(0))
    avg_loss = total_loss / len(dataloader.dataset)
    return avg_loss

def generate_new_molecules(model, smiles_token_dict, protein_seq, num_molecules, device, latent_dim, sorted_tokens):
    model.eval()
    idx_to_token = {v: k for k, v in smiles_token_dict.items()}
    generated_molecules = []

    protein_embedding = get_esm_embedding(protein_seq)
    protein_embedding = torch.tensor(protein_embedding, dtype=torch.float32).unsqueeze(0).to(device)

    gc.collect()
    torch.cuda.empty_cache()

    print(f"\nAttempting to generate {num_molecules} molecules for protein sequence...")
    with torch.no_grad():
        for i in range(num_molecules):
            z = torch.randn(1, latent_dim).to(device)
            generated_output = model.decoder(z, protein_embedding)
            probabilities = nn.functional.softmax(generated_output, dim=-1)
            predicted_indices = torch.argmax(probabilities, dim=-1).squeeze(0).cpu().numpy()

            generated_smiles = ""
            for token_idx in predicted_indices:
                token = idx_to_token.get(token_idx, '')
                if token == '`':
                    break
                generated_smiles += token

            mol = Chem.MolFromSmiles(generated_smiles)
            if mol is not None:
                print(f"Generated molecule {i+1} (valid SMILES): {generated_smiles}")
            else:
                print(f"Generated molecule {i+1} (invalid SMILES): {generated_smiles}")
            generated_molecules.append(generated_smiles)
    return generated_molecules

if __name__ == '__main__':
    target_protein_sequence = 'MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL'

    try:
        df = pd.read_csv(CSV_FILE_PATH)
        smiles_list = df['SMILES'].tolist()
        protein_list = df['Protein_Sequence'].tolist()
        print(f"Successfully loaded {len(smiles_list)} data points from '{CSV_FILE_PATH}'.")
    except FileNotFoundError:
        print(f"Error: The data file '{CSV_FILE_PATH}' was not found.")
        print("Please ensure the file is located at '/content/drive/MyDrive/Drug_Database/bindingdb_dataset.csv'.")
        exit()

    initial_count = len(smiles_list)
    filtered_smiles = []
    filtered_proteins = []
    for smiles, protein in zip(smiles_list, protein_list):
        if pd.isna(protein) or not isinstance(protein, str) or not protein.strip():
            continue
        if len(protein) <= MAX_PROTEIN_SEQUENCE_LEN:
            filtered_smiles.append(smiles)
            filtered_proteins.append(protein)
    smiles_list = filtered_smiles
    protein_list = filtered_proteins
    print(f"Filtered dataset: Retained {len(smiles_list)} rows out of {initial_count} "
          f"by removing invalid protein sequences and sequences of length > {MAX_PROTEIN_SEQUENCE_LEN}.")

    full_dataset = SmilesProteinDataset(smiles_list, protein_list, SMILES_TOKEN_DICT, SMILES_MAX_LEN, sorted_tokens, get_esm_embedding)
    train_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=ESM_EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    patience_counter = 0

    print("Starting training...")
    for epoch in range(MAX_EPOCHS):
        print(f"--- Epoch {epoch+1}/{MAX_EPOCHS} ---")
        train_loss = train_model(model, train_dataloader, optimizer, device, len(SMILES_TOKEN_DICT))
        val_loss = validate_model(model, val_dataloader, device, len(SMILES_TOKEN_DICT))
        print(f"Epoch {epoch+1}/{MAX_EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print(f"Validation loss improved. Saving model to {MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print(f"\nTraining complete. Model saved to '{MODEL_SAVE_PATH}'.")

    best_model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=ESM_EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    best_model.load_state_dict(torch.load(MODEL_SAVE_PATH))

    generate_new_molecules(best_model, SMILES_TOKEN_DICT, target_protein_sequence, num_molecules=5, device=device, latent_dim=LATENT_DIM, sorted_tokens=sorted_tokens)

Mounted at /content/drive
Using device: cuda
Loading ESM-2 model: esm2_t12_35M_UR50D to CPU...


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ESM model loaded successfully to CPU with embedding dimension: 480
Successfully loaded 3046040 data points from '/content/drive/MyDrive/Drug_Database/bindingdb_dataset.csv'.
Filtered dataset: Retained 2479480 rows out of 3046040 by removing invalid protein sequences and sequences of length > 1000.
Starting training...
--- Epoch 1/100 ---


Training:   0%|          | 2/27120 [00:15<57:25:25,  7.62s/it, batch_loss=5.06e+4]