<a href="https://colab.research.google.com/github/Swayamprakashpatel/DD/blob/main/Conditional_VAE_Drug_Generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install rdkit
!pip install torch



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
from IPython.display import Image, display
import numpy as np
import warnings
import os
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Silence every Python warning (not only RDKit-related ones) for cleaner output
warnings.filterwarnings("ignore")

# --- Global constants for SMILES tokenization ---
# Token -> column index of the one-hot encoding. Indices follow the order in
# which tokens are listed below, so appending a token never renumbers the
# existing ones. The back-tick '`' (last entry) is the padding character used
# to bring every SMILES string to a consistent length. Characters that are not
# listed here (ring-closure digits, for example) have no index.
SMILES_TOKEN_DICT = {
    token: index
    for index, token in enumerate((
        '#', '(', ')', '+', '-', '/', '=', '[', ']',
        'C', 'c', 'H', 'N', 'O', 'S', 'F', 'I',
        'P', 'n', 's', 'B', 'Cl', 'Br', 'r', 'l',
        'o', 'p', 'se', 'Se', 'a', 'i', '`',
    ))
}

# Sequence / network sizes
SMILES_MAX_LEN = 1000   # number of one-hot rows per molecule
HIDDEN_DIM = 256        # GRU and MLP hidden width
LATENT_DIM = 64         # size of the latent vector z
NUM_LAYERS = 2          # stacked GRU layers

# Optimisation schedule
MAX_EPOCHS = 100
BATCH_SIZE = 256
PATIENCE = 15           # epochs without validation improvement before stopping
LEARNING_RATE = 0.001
MODEL_SAVE_PATH = 'best_model.pt'

# --- Data Preprocessing and PyTorch Dataset ---

def smiles_to_one_hot(smiles, max_len=SMILES_MAX_LEN):
    """Encode a SMILES string as a padded one-hot matrix.

    The string is split greedily into the longest tokens found in
    ``SMILES_TOKEN_DICT`` so that multi-character entries such as ``Cl`` and
    ``Br`` map to their own column instead of being read as two unrelated
    single characters.

    Args:
        smiles: SMILES string to encode.
        max_len: Number of rows (token positions) in the result.

    Returns:
        ``numpy`` array of shape ``(max_len, len(SMILES_TOKEN_DICT))``. Each
        row holds a single 1 in the column of its token; a character that is
        not in the vocabulary keeps its position but leaves an all-zero row.
        Sequences longer than ``max_len`` are cut to ``max_len - 1`` tokens
        followed by one padding token; shorter ones are right-padded.
    """
    token_to_idx = SMILES_TOKEN_DICT
    pad_token = '`'

    # Longest candidates first so 'Cl' wins over 'C' at the same position.
    candidates = sorted(token_to_idx, key=len, reverse=True)

    tokens = []
    position = 0
    while position < len(smiles):
        match = next(
            (cand for cand in candidates if smiles.startswith(cand, position)),
            None,
        )
        if match is None:
            # Unknown character: keep its slot (zero row) like the
            # character-level encoding did, and move on by one character.
            tokens.append(None)
            position += 1
        else:
            tokens.append(match)
            position += len(match)

    if len(tokens) > max_len:
        # Truncate and terminate with the padding / end-of-string token.
        tokens = tokens[:max_len - 1] + [pad_token]
    else:
        tokens.extend([pad_token] * (max_len - len(tokens)))

    one_hot_vector = np.zeros((max_len, len(token_to_idx)))
    for row, token in enumerate(tokens):
        if token is not None:
            one_hot_vector[row, token_to_idx[token]] = 1
    return one_hot_vector

def get_protein_embedding(seq):
    """
    Placeholder protein featuriser: a 20-slot amino-acid presence vector.

    Slot ``i`` is set to 1 when the i-th letter of ``ACDEFGHIKLMNPQRSTVWY``
    occurs anywhere in ``seq``; counts and positions are not recorded and any
    other symbol is ignored. A real application would use a more sophisticated
    method, such as a pre-trained protein language model.
    """
    residue_slots = {letter: slot for slot, letter in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    presence = np.zeros(len(residue_slots))
    for residue in seq:
        slot = residue_slots.get(residue)
        if slot is not None:
            presence[slot] = 1
    return presence

class SmilesProteinDataset(Dataset):
    """Pairs each SMILES string with the protein sequence at the same index.

    Items are encoded lazily on access: the SMILES via ``smiles_to_one_hot``
    and the protein via ``get_protein_embedding``, both returned as float32
    tensors.
    """

    def __init__(self, smiles_list, protein_list):
        self.smiles_list = smiles_list
        self.protein_list = protein_list
        # Embedding width is probed once from the first protein sequence.
        probe = get_protein_embedding(self.protein_list[0])
        self.protein_embedding_dim = len(probe)

    def __len__(self):
        # The SMILES list defines the dataset length.
        return len(self.smiles_list)

    def __getitem__(self, idx):
        encoded_smiles = smiles_to_one_hot(self.smiles_list[idx])
        encoded_protein = get_protein_embedding(self.protein_list[idx])
        return (
            torch.tensor(encoded_smiles, dtype=torch.float32),
            torch.tensor(encoded_protein, dtype=torch.float32),
        )

# --- Conditional VAE Model Definition ---

class ConditionalVAE(nn.Module):
    """GRU variational autoencoder over one-hot SMILES, conditioned on a protein vector.

    The encoder summarises the SMILES sequence with a GRU, adds a projected
    protein embedding to the hidden state and maps the top layer to the
    latent mean / log-variance. The decoder seeds a second GRU from the
    latent sample concatenated with the protein embedding and emits one logit
    vector per sequence position.
    """

    def __init__(self, smiles_vocab_size, protein_embedding_dim, hidden_dim, latent_dim, smiles_max_len, num_layers):
        super().__init__()

        self.smiles_max_len = smiles_max_len
        self.smiles_vocab_size = smiles_vocab_size
        self.protein_embedding_dim = protein_embedding_dim

        # Both GRUs share the same geometry. Sub-modules are registered in a
        # fixed order (encoder first) so state_dict keys stay stable.
        gru_config = dict(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # --- Encoder: SMILES one-hot sequence + protein embedding ---
        self.encoder_rnn = nn.GRU(**gru_config)
        self.encoder_protein_mlp = nn.Sequential(
            nn.Linear(protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # --- Decoder: latent vector + protein embedding ---
        self.decoder_gru_input_mlp = nn.Sequential(
            nn.Linear(latent_dim + protein_embedding_dim, hidden_dim),
            nn.ReLU(),
        )
        self.decoder_rnn = nn.GRU(**gru_config)
        self.fc_output = nn.Linear(hidden_dim, smiles_vocab_size)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encoder(self, smiles_one_hot, protein_embedding):
        """Return ``(mu, logvar)`` for a batch of SMILES / protein pairs."""
        _, layer_states = self.encoder_rnn(smiles_one_hot)
        # Broadcast the projected protein vector over every GRU layer, add it
        # to the final hidden states and keep only the top layer.
        protein_state = self.encoder_protein_mlp(protein_embedding).unsqueeze(0)
        protein_state = protein_state.expand(layer_states.size(0), -1, -1)
        top_state = (layer_states + protein_state)[-1, :, :]
        return self.fc_mu(top_state), self.fc_logvar(top_state)

    def decoder(self, z, protein_embedding):
        """Return per-position token logits decoded from ``z`` and the protein vector."""
        conditioned = torch.cat((z, protein_embedding), dim=1)
        # Same initial hidden state for every GRU layer.
        h0 = self.decoder_gru_input_mlp(conditioned).unsqueeze(0)
        h0 = h0.repeat(self.decoder_rnn.num_layers, 1, 1)
        # The GRU input is all zeros apart from a one-hot '#' marker at step 0;
        # no target tokens are fed back in.
        step_inputs = torch.zeros(
            z.size(0), self.smiles_max_len, self.smiles_vocab_size, device=z.device
        )
        step_inputs[:, 0, SMILES_TOKEN_DICT['#']] = 1
        hidden_sequence, _ = self.decoder_rnn(step_inputs, h0)
        return self.fc_output(hidden_sequence)

    def forward(self, smiles_one_hot, protein_embedding):
        """Encode, sample a latent vector, decode; returns ``(logits, mu, logvar)``."""
        mu, logvar = self.encoder(smiles_one_hot, protein_embedding)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent, protein_embedding), mu, logvar

# --- Training Loop and Generation ---

def vae_loss(recon_x, x, mu, logvar):
    """VAE loss: summed binary cross-entropy (on logits) plus KL divergence.

    Args:
        recon_x: Decoder logits; the last dimension is the vocabulary size.
        x: One-hot targets with the same number of elements as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``, both summed (not averaged) over the batch.
    """
    # Take the vocabulary width from the logits themselves instead of the
    # module-level token table, so the loss follows whatever model produced
    # them. reshape (unlike view) also accepts non-contiguous tensors.
    vocab_size = recon_x.size(-1)
    recon_x_flat = recon_x.reshape(-1, vocab_size)
    x_flat = x.reshape(-1, vocab_size)
    BCE = nn.functional.binary_cross_entropy_with_logits(recon_x_flat, x_flat, reduction='sum')
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_model(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` with a progress bar.

    Returns the summed batch losses divided by the number of samples in
    ``dataloader.dataset``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Training")
    for batch_smiles, batch_protein in progress:
        batch_smiles, batch_protein = batch_smiles.to(device), batch_protein.to(device)

        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        recon, mu, logvar = model(batch_smiles, batch_protein)
        batch_loss = vae_loss(recon, batch_smiles, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader.dataset)

def validate_model(model, dataloader, device):
    """Compute the per-sample loss on ``dataloader`` without updating weights.

    Returns the summed batch losses divided by the number of samples in
    ``dataloader.dataset``.
    """
    model.eval()

    def _batch_loss(batch_smiles, batch_protein):
        # Forward pass and loss for one batch, as a plain Python float.
        batch_smiles = batch_smiles.to(device)
        batch_protein = batch_protein.to(device)
        recon, mu, logvar = model(batch_smiles, batch_protein)
        return vae_loss(recon, batch_smiles, mu, logvar).item()

    with torch.no_grad():
        summed_loss = sum(
            _batch_loss(batch_smiles, batch_protein)
            for batch_smiles, batch_protein in tqdm(dataloader, desc="Validation")
        )

    return summed_loss / len(dataloader.dataset)

def generate_new_molecules(model, protein_seq, num_molecules=5, device="cpu"):
    """Sample latent vectors and decode them into SMILES for one protein.

    Each sample is decoded greedily (arg-max token per position) up to the
    first padding token, checked with RDKit and, when it parses, drawn inline.

    Args:
        model: Trained model exposing ``decoder(z, protein_embedding)``.
        protein_seq: Amino-acid sequence used as the conditioning input.
        num_molecules: Number of latent samples to decode.
        device: Device on which the tensors are created.

    Returns:
        List of the decoded SMILES strings, valid or not, in sampling order.
    """
    model.eval()
    idx_to_token = {v: k for k, v in SMILES_TOKEN_DICT.items()}
    generated_molecules = []

    # Prepare the protein embedding for a single sequence (batch of one).
    protein_embedding = get_protein_embedding(protein_seq)
    protein_embedding = torch.tensor(protein_embedding, dtype=torch.float32).unsqueeze(0).to(device)

    print(f"\nAttempting to generate {num_molecules} molecules for protein sequence: {protein_seq}")
    with torch.no_grad():
        for _ in range(num_molecules):
            # Sample a vector from the latent prior.
            z = torch.randn(1, LATENT_DIM).to(device)
            generated_output = model.decoder(z, protein_embedding)

            # Softmax is monotonic, so the arg-max of the raw logits already
            # selects the most likely token at each position.
            predicted_indices = torch.argmax(generated_output, dim=-1).squeeze(0).tolist()

            pieces = []
            for token_idx in predicted_indices:
                token = idx_to_token.get(token_idx, '')
                # Stop decoding when a padding token is encountered.
                if token == '`':
                    break
                pieces.append(token)
            generated_smiles = "".join(pieces)

            # An empty decode describes no molecule. Do not hand it to RDKit,
            # which may return an atom-less molecule rather than None and
            # would make the sample look like a success.
            mol = Chem.MolFromSmiles(generated_smiles) if generated_smiles else None
            if mol is not None:
                print(f"Generated molecule (valid SMILES): {generated_smiles}")
                img = MolToImage(mol)
                display(img)
            else:
                print(f"Generated molecule (invalid SMILES): {generated_smiles}")

            generated_molecules.append(generated_smiles)

    return generated_molecules

if __name__ == '__main__':
    # End-to-end script: load the SMILES/protein pairs, train with early
    # stopping on the validation loss, reload the best checkpoint and sample
    # molecules for one target protein.

    # Determine the device to use (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define an example protein sequence to predict for
    target_protein_sequence = 'MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL'

    # --- Training Process ---
    try:
        df = pd.read_csv('final_output_15_2_25.csv')
    except FileNotFoundError:
        print("Error: The data file 'final_output_15_2_25.csv' was not found. Please ensure it is uploaded.")
        # Stop with a non-zero status instead of the interactive-only exit().
        raise SystemExit(1)

    # Rows with a missing SMILES or sequence are read by pandas as float NaN,
    # which the string-based encoders cannot iterate; drop them up front.
    df = df.dropna(subset=['SMILES', 'TARGET_SEQUENCE'])
    smiles_list = df['SMILES'].astype(str).tolist()
    protein_list = df['TARGET_SEQUENCE'].astype(str).tolist()

    # Create dataset and dataloader (70 % train / 30 % validation)
    full_dataset = SmilesProteinDataset(smiles_list, protein_list)
    train_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize the model and optimizer
    protein_embedding_dim = len(get_protein_embedding(target_protein_sequence))
    model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    patience_counter = 0

    print("Starting training...")
    for epoch in range(MAX_EPOCHS):
        print(f"--- Epoch {epoch+1}/{MAX_EPOCHS} ---")
        train_loss = train_model(model, train_dataloader, optimizer, device)
        val_loss = validate_model(model, val_dataloader, device)

        print(f"Epoch {epoch+1}/{MAX_EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Check for improvement and save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print(f"Validation loss improved. Saving model to {MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    if best_val_loss == float('inf'):
        # No epoch beat the initial sentinel (NaN/inf losses, or no epochs at
        # all), so nothing was written this run. Save the current weights so
        # the reload below does not fail or pick up a stale checkpoint.
        print("Warning: validation loss never improved; saving the final weights instead.")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)

    print(f"\nTraining complete. Model saved to '{MODEL_SAVE_PATH}'.")

    # --- Generation Process ---
    # We load the best model to ensure we are using the one with the best performance
    best_model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    # map_location keeps the load working when the checkpoint was written on
    # a different device than the one in use now.
    best_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

    # Generate new molecules using the loaded model
    generate_new_molecules(best_model, target_protein_sequence, num_molecules=5, device=device)


In [19]:
## Code Update Summary
#---
#Here is a summary of the updates made to the code. You can copy and paste this into a cell at the top of your notebook.

### Key Changes
#* **Improved SMILES Tokenization:** The `smiles_to_one_hot` function has been updated to correctly handle multi-character SMILES tokens (e.g., `Cl`, `Br`). This prevents the model from incorrectly interpreting these as two separate tokens, which was a significant issue in the previous version.
#* **Expanded Token Dictionary:** The `SMILES_TOKEN_DICT` now includes more common two-character elements such as `Mg`, `Fe`, and `Zn` for more accurate SMILES parsing.
#* **Hyperparameter Tuning Notes:** The model's validation loss reaching a plateau and triggering early stopping after 20-25 epochs suggests that it has either converged or is starting to overfit. To potentially improve the final loss, you can try adjusting hyperparameters like the `LEARNING_RATE` or `BATCH_SIZE`.

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
from IPython.display import Image, display
import numpy as np
import warnings
import os
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Silence every Python warning (not only RDKit-related ones) for cleaner output
warnings.filterwarnings("ignore")

# --- Global constants for SMILES tokenization ---
# Token -> column index of the one-hot encoding; indices follow listing order.
# Multi-character entries ('Cl', 'Br', 'Mg', ...) are matched as single tokens
# by `get_smiles_tokens`. The back-tick '`' is the padding token.
SMILES_TOKEN_DICT = {
    token: index
    for index, token in enumerate((
        '#', '(', ')', '+', '-', '/', '=', '[', ']',
        'C', 'c', 'H', 'N', 'O', 'S', 'F', 'I',
        'P', 'n', 's', 'B', 'Cl', 'Br', 'r', 'l',
        'o', 'p', 'se', 'Se', 'a', 'i', '`',
        'Mg', 'Fe', 'Zn', 'Si',
    ))
}

# Vocabulary ordered longest-first so the tokenizer tries 'Cl' before 'C'.
# The sort is stable: tokens of equal length keep their dictionary order.
sorted_tokens = sorted(SMILES_TOKEN_DICT, key=len, reverse=True)

# Sequence / network sizes
SMILES_MAX_LEN = 1000   # number of one-hot rows per molecule
HIDDEN_DIM = 256        # GRU and MLP hidden width
LATENT_DIM = 64         # size of the latent vector z
NUM_LAYERS = 2          # stacked GRU layers

# Optimisation schedule
MAX_EPOCHS = 100
BATCH_SIZE = 256
PATIENCE = 15           # epochs without validation improvement before stopping
LEARNING_RATE = 0.001
MODEL_SAVE_PATH = 'best_model.pt'

# --- Data Preprocessing and PyTorch Dataset ---
def get_smiles_tokens(smiles_string):
    """
    Split a SMILES string into vocabulary tokens, longest match first.

    At each position the first entry of `sorted_tokens` that matches is
    consumed; a character that matches no entry is skipped without producing
    a token.
    """
    pieces = []
    cursor = 0
    end = len(smiles_string)
    while cursor < end:
        hit = next(
            (cand for cand in sorted_tokens if smiles_string.startswith(cand, cursor)),
            None,
        )
        if hit is None:
            # Nothing in the vocabulary starts here: drop this character.
            cursor += 1
            continue
        pieces.append(hit)
        cursor += len(hit)
    return pieces

def smiles_to_one_hot(smiles, max_len=SMILES_MAX_LEN):
    """Tokenize a SMILES string and return its padded one-hot matrix.

    The result is a float32 array of shape ``(max_len, vocabulary size)``.
    Token sequences longer than ``max_len`` are cut; shorter ones are filled
    up with the padding token '`'.
    """
    vocab = SMILES_TOKEN_DICT

    # Clip to max_len, then right-pad with the end-of-string token.
    sequence = get_smiles_tokens(smiles)[:max_len]
    sequence = sequence + ['`'] * (max_len - len(sequence))

    encoded = np.zeros((max_len, len(vocab)), dtype=np.float32)
    for row, symbol in enumerate(sequence):
        column = vocab.get(symbol)
        if column is not None:
            encoded[row, column] = 1
    return encoded

def get_protein_embedding(seq):
    """
    Placeholder protein featuriser: a 20-slot amino-acid presence vector.

    Slot ``i`` is set to 1 when the i-th letter of ``ACDEFGHIKLMNPQRSTVWY``
    occurs anywhere in ``seq``; counts and positions are not recorded and any
    other symbol is ignored. A real application would use a more sophisticated
    method, such as a pre-trained protein language model (e.g., from the
    ProtTrans or ESM family), for better performance on complex sequences.
    """
    residue_slots = {letter: slot for slot, letter in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    presence = np.zeros(len(residue_slots), dtype=np.float32)
    for residue in seq:
        slot = residue_slots.get(residue)
        if slot is not None:
            presence[slot] = 1
    return presence

class SmilesProteinDataset(Dataset):
    """Pairs each SMILES string with the protein sequence at the same index.

    Items are encoded lazily on access: the SMILES via ``smiles_to_one_hot``
    and the protein via ``get_protein_embedding``, both returned as float32
    tensors.
    """

    def __init__(self, smiles_list, protein_list):
        self.smiles_list = smiles_list
        self.protein_list = protein_list
        # Embedding width is probed once from the first protein sequence.
        probe = get_protein_embedding(self.protein_list[0])
        self.protein_embedding_dim = len(probe)

    def __len__(self):
        # The SMILES list defines the dataset length.
        return len(self.smiles_list)

    def __getitem__(self, idx):
        encoded_smiles = smiles_to_one_hot(self.smiles_list[idx])
        encoded_protein = get_protein_embedding(self.protein_list[idx])
        return (
            torch.tensor(encoded_smiles, dtype=torch.float32),
            torch.tensor(encoded_protein, dtype=torch.float32),
        )

# --- Conditional VAE Model Definition ---

class ConditionalVAE(nn.Module):
    """GRU variational autoencoder over one-hot SMILES, conditioned on a protein vector.

    The encoder summarises the SMILES sequence with a GRU, adds a projected
    protein embedding to the hidden state and maps the top layer to the
    latent mean / log-variance. The decoder seeds a second GRU from the
    latent sample concatenated with the protein embedding and emits one logit
    vector per sequence position.
    """

    def __init__(self, smiles_vocab_size, protein_embedding_dim, hidden_dim, latent_dim, smiles_max_len, num_layers):
        super().__init__()

        self.smiles_max_len = smiles_max_len
        self.smiles_vocab_size = smiles_vocab_size
        self.protein_embedding_dim = protein_embedding_dim

        # Both GRUs share the same geometry. Sub-modules are registered in a
        # fixed order (encoder first) so state_dict keys stay stable.
        gru_config = dict(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        # --- Encoder: SMILES one-hot sequence + protein embedding ---
        self.encoder_rnn = nn.GRU(**gru_config)
        self.encoder_protein_mlp = nn.Sequential(
            nn.Linear(protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # --- Decoder: latent vector + protein embedding ---
        self.decoder_gru_input_mlp = nn.Sequential(
            nn.Linear(latent_dim + protein_embedding_dim, hidden_dim),
            nn.ReLU(),
        )
        self.decoder_rnn = nn.GRU(**gru_config)
        self.fc_output = nn.Linear(hidden_dim, smiles_vocab_size)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encoder(self, smiles_one_hot, protein_embedding):
        """Return ``(mu, logvar)`` for a batch of SMILES / protein pairs."""
        _, layer_states = self.encoder_rnn(smiles_one_hot)
        # Broadcast the projected protein vector over every GRU layer, add it
        # to the final hidden states and keep only the top layer.
        protein_state = self.encoder_protein_mlp(protein_embedding).unsqueeze(0)
        protein_state = protein_state.expand(layer_states.size(0), -1, -1)
        top_state = (layer_states + protein_state)[-1, :, :]
        return self.fc_mu(top_state), self.fc_logvar(top_state)

    def decoder(self, z, protein_embedding):
        """Return per-position token logits decoded from ``z`` and the protein vector."""
        conditioned = torch.cat((z, protein_embedding), dim=1)
        # Same initial hidden state for every GRU layer.
        h0 = self.decoder_gru_input_mlp(conditioned).unsqueeze(0)
        h0 = h0.repeat(self.decoder_rnn.num_layers, 1, 1)
        # The GRU input is all zeros apart from a one-hot '#' marker at step 0;
        # no target tokens are fed back in.
        step_inputs = torch.zeros(
            z.size(0), self.smiles_max_len, self.smiles_vocab_size, device=z.device
        )
        step_inputs[:, 0, SMILES_TOKEN_DICT['#']] = 1
        hidden_sequence, _ = self.decoder_rnn(step_inputs, h0)
        return self.fc_output(hidden_sequence)

    def forward(self, smiles_one_hot, protein_embedding):
        """Encode, sample a latent vector, decode; returns ``(logits, mu, logvar)``."""
        mu, logvar = self.encoder(smiles_one_hot, protein_embedding)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent, protein_embedding), mu, logvar

# --- Training Loop and Generation ---

def vae_loss(recon_x, x, mu, logvar):
    """VAE loss: summed binary cross-entropy (on logits) plus KL divergence.

    Args:
        recon_x: Decoder logits; the last dimension is the vocabulary size.
        x: One-hot targets with the same number of elements as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor ``BCE + KLD``, both summed (not averaged) over the batch.
    """
    # Take the vocabulary width from the logits themselves instead of the
    # module-level token table, so the loss follows whatever model produced
    # them. reshape (unlike view) also accepts non-contiguous tensors.
    vocab_size = recon_x.size(-1)
    recon_x_flat = recon_x.reshape(-1, vocab_size)
    x_flat = x.reshape(-1, vocab_size)
    BCE = nn.functional.binary_cross_entropy_with_logits(recon_x_flat, x_flat, reduction='sum')
    # Closed-form KL between N(mu, exp(logvar)) and the standard normal.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train_model(model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` with a progress bar.

    Returns the summed batch losses divided by the number of samples in
    ``dataloader.dataset``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Training")
    for batch_smiles, batch_protein in progress:
        batch_smiles, batch_protein = batch_smiles.to(device), batch_protein.to(device)

        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        recon, mu, logvar = model(batch_smiles, batch_protein)
        batch_loss = vae_loss(recon, batch_smiles, mu, logvar)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader.dataset)

def validate_model(model, dataloader, device):
    """Compute the per-sample loss on ``dataloader`` without updating weights.

    Returns the summed batch losses divided by the number of samples in
    ``dataloader.dataset``.
    """
    model.eval()

    def _batch_loss(batch_smiles, batch_protein):
        # Forward pass and loss for one batch, as a plain Python float.
        batch_smiles = batch_smiles.to(device)
        batch_protein = batch_protein.to(device)
        recon, mu, logvar = model(batch_smiles, batch_protein)
        return vae_loss(recon, batch_smiles, mu, logvar).item()

    with torch.no_grad():
        summed_loss = sum(
            _batch_loss(batch_smiles, batch_protein)
            for batch_smiles, batch_protein in tqdm(dataloader, desc="Validation")
        )

    return summed_loss / len(dataloader.dataset)

def generate_new_molecules(model, protein_seq, num_molecules=5, device="cpu"):
    """Sample latent vectors and decode them into SMILES for one protein.

    Each sample is decoded greedily (arg-max token per position) up to the
    first padding token, checked with RDKit and, when it parses, drawn inline.

    Args:
        model: Trained model exposing ``decoder(z, protein_embedding)``.
        protein_seq: Amino-acid sequence used as the conditioning input.
        num_molecules: Number of latent samples to decode.
        device: Device on which the tensors are created.

    Returns:
        List of the decoded SMILES strings, valid or not, in sampling order.
    """
    model.eval()
    idx_to_token = {v: k for k, v in SMILES_TOKEN_DICT.items()}
    generated_molecules = []

    # Prepare the protein embedding for a single sequence (batch of one).
    protein_embedding = get_protein_embedding(protein_seq)
    protein_embedding = torch.tensor(protein_embedding, dtype=torch.float32).unsqueeze(0).to(device)

    print(f"\nAttempting to generate {num_molecules} molecules for protein sequence: {protein_seq}")
    with torch.no_grad():
        for _ in range(num_molecules):
            # Sample a vector from the latent prior.
            z = torch.randn(1, LATENT_DIM).to(device)
            generated_output = model.decoder(z, protein_embedding)

            # Softmax is monotonic, so the arg-max of the raw logits already
            # selects the most likely token at each position.
            predicted_indices = torch.argmax(generated_output, dim=-1).squeeze(0).tolist()

            pieces = []
            for token_idx in predicted_indices:
                token = idx_to_token.get(token_idx, '')
                # Stop decoding when a padding token is encountered.
                if token == '`':
                    break
                pieces.append(token)
            generated_smiles = "".join(pieces)

            # An empty decode describes no molecule. Do not hand it to RDKit,
            # which may return an atom-less molecule rather than None and
            # would make the sample look like a success.
            mol = Chem.MolFromSmiles(generated_smiles) if generated_smiles else None
            if mol is not None:
                print(f"Generated molecule (valid SMILES): {generated_smiles}")
                img = MolToImage(mol)
                display(img)
            else:
                print(f"Generated molecule (invalid SMILES): {generated_smiles}")

            generated_molecules.append(generated_smiles)

    return generated_molecules

if __name__ == '__main__':
    # End-to-end script: load the SMILES/protein pairs, train with early
    # stopping on the validation loss, reload the best checkpoint and sample
    # molecules for one target protein.

    # Determine the device to use (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define an example protein sequence to predict for
    target_protein_sequence = 'MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL'

    # --- Training Process ---
    try:
        df = pd.read_csv('final_output_15_2_25.csv')
    except FileNotFoundError:
        print("Error: The data file 'final_output_15_2_25.csv' was not found. Please ensure it is uploaded.")
        # Stop with a non-zero status instead of the interactive-only exit().
        raise SystemExit(1)

    # Rows with a missing SMILES or sequence are read by pandas as float NaN,
    # which the string-based encoders cannot iterate; drop them up front.
    df = df.dropna(subset=['SMILES', 'TARGET_SEQUENCE'])
    smiles_list = df['SMILES'].astype(str).tolist()
    protein_list = df['TARGET_SEQUENCE'].astype(str).tolist()

    # Create dataset and dataloader (70 % train / 30 % validation)
    full_dataset = SmilesProteinDataset(smiles_list, protein_list)
    train_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize the model and optimizer
    protein_embedding_dim = len(get_protein_embedding(target_protein_sequence))
    model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    patience_counter = 0

    print("Starting training...")
    for epoch in range(MAX_EPOCHS):
        print(f"--- Epoch {epoch+1}/{MAX_EPOCHS} ---")
        train_loss = train_model(model, train_dataloader, optimizer, device)
        val_loss = validate_model(model, val_dataloader, device)

        print(f"Epoch {epoch+1}/{MAX_EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Check for improvement and save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print(f"Validation loss improved. Saving model to {MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    if best_val_loss == float('inf'):
        # No epoch beat the initial sentinel (NaN/inf losses, or no epochs at
        # all), so nothing was written this run. Save the current weights so
        # the reload below does not fail or pick up a stale checkpoint.
        print("Warning: validation loss never improved; saving the final weights instead.")
        torch.save(model.state_dict(), MODEL_SAVE_PATH)

    print(f"\nTraining complete. Model saved to '{MODEL_SAVE_PATH}'.")

    # --- Generation Process ---
    # We load the best model to ensure we are using the one with the best performance
    best_model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    # map_location keeps the load working when the checkpoint was written on
    # a different device than the one in use now.
    best_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

    # Generate new molecules using the loaded model
    generate_new_molecules(best_model, target_protein_sequence, num_molecules=5, device=device)


Using device: cuda
Starting training...
--- Epoch 1/100 ---


Training: 100%|██████████| 86/86 [00:34<00:00,  2.53it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.82it/s]


Epoch 1/100, Train Loss: 2337.1510, Validation Loss: 272.6223
Validation loss improved. Saving model to best_model.pt
--- Epoch 2/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.56it/s]
Validation: 100%|██████████| 37/37 [00:08<00:00,  4.52it/s]


Epoch 2/100, Train Loss: 259.5445, Validation Loss: 252.9564
Validation loss improved. Saving model to best_model.pt
--- Epoch 3/100 ---


Training: 100%|██████████| 86/86 [00:34<00:00,  2.52it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.72it/s]


Epoch 3/100, Train Loss: 248.9015, Validation Loss: 249.7019
Validation loss improved. Saving model to best_model.pt
--- Epoch 4/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.54it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.67it/s]


Epoch 4/100, Train Loss: 246.4244, Validation Loss: 245.5211
Validation loss improved. Saving model to best_model.pt
--- Epoch 5/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.53it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.81it/s]


Epoch 5/100, Train Loss: 243.8195, Validation Loss: 246.3971
Validation loss did not improve. Patience: 1/15
--- Epoch 6/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.57it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.78it/s]


Epoch 6/100, Train Loss: 243.0959, Validation Loss: 242.9432
Validation loss improved. Saving model to best_model.pt
--- Epoch 7/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.55it/s]
Validation: 100%|██████████| 37/37 [00:08<00:00,  4.51it/s]


Epoch 7/100, Train Loss: 241.4271, Validation Loss: 242.5930
Validation loss improved. Saving model to best_model.pt
--- Epoch 8/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.53it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.70it/s]


Epoch 8/100, Train Loss: 242.1518, Validation Loss: 246.1286
Validation loss did not improve. Patience: 1/15
--- Epoch 9/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.56it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.78it/s]


Epoch 9/100, Train Loss: 241.9855, Validation Loss: 242.9735
Validation loss did not improve. Patience: 2/15
--- Epoch 10/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.54it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.76it/s]


Epoch 10/100, Train Loss: 241.5349, Validation Loss: 241.5820
Validation loss improved. Saving model to best_model.pt
--- Epoch 11/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.56it/s]
Validation: 100%|██████████| 37/37 [00:07<00:00,  4.81it/s]


Epoch 11/100, Train Loss: 240.1783, Validation Loss: 248.1282
Validation loss did not improve. Patience: 1/15
--- Epoch 12/100 ---


Training: 100%|██████████| 86/86 [00:33<00:00,  2.53it/s]
Validation:  92%|█████████▏| 34/37 [00:07<00:00,  4.47it/s]


KeyboardInterrupt: 

In [20]:
# All periodic-table elements are included as tokenization characters

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Draw import MolToImage
from IPython.display import Image, display
import numpy as np
import warnings
import os
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Silence every Python warning process-wide (no category or module filter is
# passed) so the notebook output stays uncluttered.
# NOTE(review): the original intent was to hide RDKit warnings only; confirm
# that this blanket filter is wanted, since it hides warnings from all libraries.
warnings.filterwarnings("ignore")

# --- Global constants for a comprehensive SMILES vocabulary ---
def create_comprehensive_smiles_vocab():
    """
    Build the SMILES token vocabulary used for one-hot encoding.

    The vocabulary is the union of every periodic-table element symbol, the
    digits 0-9, and a set of single-character SMILES symbols (brackets, bond
    and charge characters, aromatic atoms, and the '`' padding character).

    Returns:
        tuple: ``(smiles_token_dict, sorted_tokens)``. ``sorted_tokens`` lists
        the tokens longest-first, with equal-length tokens in ascending string
        order; ``smiles_token_dict`` maps each token to its position in that
        list.
    """
    periodic_table_elements = [
        'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si',
        'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
        'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb',
        'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
        'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho',
        'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
        'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np',
        'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg',
        'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'
    ]
    # Single-character SMILES symbols. '#' (triple bond) and '`' (padding)
    # are part of this list, so no separate insertion is needed afterwards.
    single_char_tokens = [
        '(', ')', '[', ']', '+', '-', '=', '#', '.', ':', '/', '\\',
        'c', 'n', 'o', 's', '`', '@'
    ]
    # Ring-closure / count digits
    numbers = [str(i) for i in range(10)]

    # De-duplicate the combined token pool
    all_tokens = set(periodic_table_elements + single_char_tokens + numbers)

    # Longest tokens first so that a greedy tokenizer prefers 'Cl' over 'C'.
    # The token itself is the tie-breaker: iterating a set of strings follows
    # the per-process hash seed, so sorting by length alone would hand out
    # different indices in every Python session and a checkpoint saved in one
    # session would not line up with the vocabulary rebuilt in the next.
    sorted_tokens = sorted(all_tokens, key=lambda token: (-len(token), token))
    # Token -> index lookup used for one-hot encoding
    smiles_token_dict = {token: i for i, token in enumerate(sorted_tokens)}

    return smiles_token_dict, sorted_tokens

# Token -> index mapping and the longest-first token list used by the tokenizer.
SMILES_TOKEN_DICT, sorted_tokens = create_comprehensive_smiles_vocab()
# Number of token positions every SMILES is padded or truncated to.
SMILES_MAX_LEN = 1000
# Hidden size of the encoder/decoder GRUs and the protein MLP.
HIDDEN_DIM = 256
# Dimensionality of the VAE latent vector z.
LATENT_DIM = 64
# Number of stacked GRU layers in both encoder and decoder.
NUM_LAYERS = 2
# Upper bound on training epochs (early stopping may end training sooner).
MAX_EPOCHS = 100
# Samples per batch for both the training and validation loaders.
BATCH_SIZE = 256
# Consecutive epochs without validation improvement before early stopping.
PATIENCE = 15
# Step size passed to the Adam optimizer.
LEARNING_RATE = 1e-3
# File the best model's state_dict is written to and later reloaded from.
MODEL_SAVE_PATH = 'best_model.pt'

# --- Data Preprocessing and PyTorch Dataset ---
def get_smiles_tokens(smiles_string, special_tokens):
    """
    Split a SMILES string into tokens by greedy prefix matching.

    At every position the first entry of ``special_tokens`` that the string
    starts with (at that position) is consumed as one token. When no entry
    matches, the single character at that position becomes its own token.
    """
    pieces = []
    pos = 0
    total = len(smiles_string)
    while pos < total:
        # First candidate (in list order) that matches at the current position
        hit = next(
            (cand for cand in special_tokens if smiles_string.startswith(cand, pos)),
            None,
        )
        if hit is None:
            # Unknown character: emit it verbatim and move on by one
            pieces.append(smiles_string[pos])
            pos += 1
        else:
            pieces.append(hit)
            pos += len(hit)
    return pieces

def smiles_to_one_hot(smiles, token_to_idx, max_len, sorted_tokens):
    """
    Turn a SMILES string into a (max_len, vocab_size) float32 one-hot matrix.

    The string is tokenized with ``sorted_tokens``, cut to at most ``max_len``
    tokens and right-padded with the '`' token. Tokens missing from
    ``token_to_idx`` leave their row all zeros.
    """
    # Tokenize, then clip to the fixed sequence length
    token_seq = get_smiles_tokens(smiles, sorted_tokens)[:max_len]
    # Fill the remaining positions with the padding token
    token_seq.extend('`' for _ in range(max_len - len(token_seq)))

    encoded = np.zeros((max_len, len(token_to_idx)), dtype=np.float32)
    for row, tok in enumerate(token_seq):
        col = token_to_idx.get(tok)
        if col is not None:
            encoded[row, col] = 1
    return encoded

def get_protein_embedding(seq):
    """
    Return a 20-dimensional float32 presence vector for a protein sequence.

    Each slot corresponds to one standard amino-acid letter and is set to 1
    if that letter occurs anywhere in ``seq``; other characters are ignored.
    This is only a placeholder featurization: a real application would use a
    pre-trained protein language model (e.g. from the ProtTrans or ESM
    family) to obtain richer embeddings.
    """
    # Same letter -> slot assignment as an explicit A..Y lookup table
    residue_index = {aa: pos for pos, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
    presence = np.zeros(len(residue_index), dtype=np.float32)
    for residue in seq:
        slot = residue_index.get(residue)
        if slot is not None:
            presence[slot] = 1
    return presence

class SmilesProteinDataset(Dataset):
    """Pairs each SMILES string with its protein sequence and encodes both on access."""

    def __init__(self, smiles_list, protein_list, smiles_token_dict, smiles_max_len, sorted_tokens):
        self.smiles_list, self.protein_list = smiles_list, protein_list
        self.smiles_token_dict = smiles_token_dict
        self.smiles_max_len = smiles_max_len
        self.sorted_tokens = sorted_tokens
        # Probe the first sequence to learn the protein embedding width
        first_embedding = get_protein_embedding(self.protein_list[0])
        self.protein_embedding_dim = len(first_embedding)

    def __len__(self):
        # One sample per SMILES entry
        return len(self.smiles_list)

    def __getitem__(self, idx):
        # Look up the raw pair, then encode each half
        raw_smiles, raw_protein = self.smiles_list[idx], self.protein_list[idx]
        encoded_smiles = smiles_to_one_hot(
            raw_smiles, self.smiles_token_dict, self.smiles_max_len, self.sorted_tokens
        )
        encoded_protein = get_protein_embedding(raw_protein)
        return (
            torch.tensor(encoded_smiles, dtype=torch.float32),
            torch.tensor(encoded_protein, dtype=torch.float32),
        )

# --- Conditional VAE Model Definition ---

class ConditionalVAE(nn.Module):
    """
    GRU-based variational autoencoder over one-hot SMILES sequences,
    conditioned on a protein embedding vector.

    The encoder summarizes the SMILES with a GRU, adds an MLP projection of
    the protein embedding to the top-layer hidden state and maps the result
    to the latent mean and log-variance. The decoder turns the latent vector
    concatenated with the protein embedding into the initial GRU hidden
    state and unrolls the GRU for ``smiles_max_len`` steps on an all-zero
    input, producing per-position token logits.

    Args:
        smiles_vocab_size: Number of SMILES tokens (one-hot width).
        protein_embedding_dim: Width of the protein embedding vector.
        hidden_dim: Hidden size of both GRUs and the protein MLP.
        latent_dim: Size of the latent vector z.
        smiles_max_len: Number of positions the decoder emits.
        num_layers: Number of stacked GRU layers in encoder and decoder.
    """

    def __init__(self, smiles_vocab_size, protein_embedding_dim, hidden_dim, latent_dim, smiles_max_len, num_layers):
        super().__init__()

        self.smiles_max_len = smiles_max_len
        self.smiles_vocab_size = smiles_vocab_size
        self.protein_embedding_dim = protein_embedding_dim

        # Encoder: takes smiles and protein embedding as input
        self.encoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.encoder_protein_mlp = nn.Sequential(
            nn.Linear(protein_embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: takes latent vector and protein embedding as input
        self.decoder_gru_input_mlp = nn.Sequential(
            nn.Linear(latent_dim + protein_embedding_dim, hidden_dim),
            nn.ReLU()
        )
        self.decoder_rnn = nn.GRU(
            input_size=smiles_vocab_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc_output = nn.Linear(hidden_dim, smiles_vocab_size)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoder(self, smiles_one_hot, protein_embedding):
        """
        Map a batch of one-hot SMILES and protein embeddings to (mu, logvar).

        Args:
            smiles_one_hot: Tensor of shape (batch, seq_len, smiles_vocab_size).
            protein_embedding: Tensor of shape (batch, protein_embedding_dim).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        _, hidden_smiles = self.encoder_rnn(smiles_one_hot)
        # Only the top GRU layer feeds the latent heads, so the protein
        # projection is added to that layer alone rather than being tiled
        # across every layer and then discarded.
        final_hidden_state = hidden_smiles[-1] + self.encoder_protein_mlp(protein_embedding)
        mu = self.fc_mu(final_hidden_state)
        logvar = self.fc_logvar(final_hidden_state)
        return mu, logvar

    def decoder(self, z, protein_embedding):
        """
        Decode latent vectors into per-position token logits.

        Args:
            z: Tensor of shape (batch, latent_dim).
            protein_embedding: Tensor of shape (batch, protein_embedding_dim).

        Returns:
            Logits of shape (batch, smiles_max_len, smiles_vocab_size).
        """
        z_with_protein = torch.cat((z, protein_embedding), dim=1)
        initial_hidden = self.decoder_gru_input_mlp(z_with_protein).unsqueeze(0)
        # Every GRU layer starts from the same conditioned hidden state
        initial_hidden = initial_hidden.repeat(self.decoder_rnn.num_layers, 1, 1)
        # The GRU input carries no information (all zeros); everything the
        # decoder knows comes from the initial hidden state. Allocate it
        # directly on the target device with the hidden state's dtype so no
        # host-to-device copy is needed and non-float32 models work.
        decoder_input = torch.zeros(
            z.size(0),
            self.smiles_max_len,
            self.smiles_vocab_size,
            dtype=initial_hidden.dtype,
            device=z.device,
        )
        output, _ = self.decoder_rnn(decoder_input, initial_hidden)
        output = self.fc_output(output)
        return output

    def forward(self, smiles_one_hot, protein_embedding):
        """Encode, sample z, decode; returns ``(logits, mu, logvar)``."""
        mu, logvar = self.encoder(smiles_one_hot, protein_embedding)
        z = self.reparameterize(mu, logvar)
        reconstructed_smiles = self.decoder(z, protein_embedding)
        return reconstructed_smiles, mu, logvar

# --- Training Loop and Generation ---

def vae_loss(recon_x, x, mu, logvar, smiles_vocab_size):
    """
    Summed VAE objective: reconstruction BCE (on logits) plus KL divergence.

    ``recon_x`` holds raw logits and ``x`` the one-hot targets; both are
    flattened to rows of width ``smiles_vocab_size`` before the element-wise
    binary cross-entropy is summed. The KL term compares N(mu, exp(logvar))
    with the standard normal prior.
    """
    # KL divergence to the unit Gaussian, summed over batch and latent dims
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()

    # Flatten to (positions, vocab) so logits and targets line up row by row
    logits = recon_x.view(-1, smiles_vocab_size)
    targets = x.view(-1, smiles_vocab_size)
    reconstruction_term = nn.functional.binary_cross_entropy_with_logits(
        logits, targets, reduction='sum'
    )
    return reconstruction_term + kl_term

def train_model(model, dataloader, optimizer, device, smiles_vocab_size):
    """
    Run one optimization pass over ``dataloader`` and return the mean loss.

    The summed per-batch VAE loss is accumulated and divided by the number
    of samples in ``dataloader.dataset``. A tqdm bar reports progress.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Training")
    for batch_smiles, batch_protein in progress:
        batch_smiles, batch_protein = batch_smiles.to(device), batch_protein.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        recon, mu, logvar = model(batch_smiles, batch_protein)
        batch_loss = vae_loss(recon, batch_smiles, mu, logvar, smiles_vocab_size)

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Per-sample average: the loss is summed (not averaged) within a batch
    return running_loss / len(dataloader.dataset)

def validate_model(model, dataloader, device, smiles_vocab_size):
    """
    Compute the mean per-sample VAE loss on ``dataloader`` without updating weights.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. A tqdm bar reports progress.
    """
    model.eval()
    running_loss = 0
    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation")
        for batch_smiles, batch_protein in progress:
            batch_smiles, batch_protein = batch_smiles.to(device), batch_protein.to(device)

            recon, mu, logvar = model(batch_smiles, batch_protein)
            running_loss += vae_loss(recon, batch_smiles, mu, logvar, smiles_vocab_size).item()

    # Per-sample average: the loss is summed (not averaged) within a batch
    return running_loss / len(dataloader.dataset)

def generate_new_molecules(model, smiles_token_dict, protein_seq, num_molecules, device, latent_dim):
    """
    Sample SMILES strings from the decoder for one protein sequence.

    For each molecule a latent vector is drawn from N(0, I), decoded together
    with the protein embedding, and converted to text by taking the most
    likely token at every position up to the first padding token. Each result
    is printed, and drawn when RDKit parses it into a molecule with at least
    one atom.

    Args:
        model: Trained model exposing ``decoder(z, protein_embedding)``.
        smiles_token_dict: Token -> index mapping used during training.
        protein_seq: Amino-acid sequence to condition on.
        num_molecules: Number of samples to draw.
        device: Torch device the model lives on.
        latent_dim: Size of the latent vector expected by the decoder.

    Returns:
        list[str]: The decoded strings, valid or not, in generation order.
    """
    model.eval()
    idx_to_token = {idx: token for token, idx in smiles_token_dict.items()}
    generated_molecules = []

    # Prepare the protein embedding for a single sequence (batch of one)
    protein_embedding = get_protein_embedding(protein_seq)
    protein_embedding = torch.tensor(protein_embedding, dtype=torch.float32).unsqueeze(0).to(device)

    print(f"\nAttempting to generate {num_molecules} molecules for protein sequence: {protein_seq}")
    with torch.no_grad():
        for _ in range(num_molecules):
            # Sample a vector from the latent prior
            z = torch.randn(1, latent_dim).to(device)
            logits = model.decoder(z, protein_embedding)

            # Softmax is monotonic, so the argmax of the raw logits already
            # identifies the most likely token at each position.
            predicted_indices = torch.argmax(logits, dim=-1).squeeze(0).tolist()

            tokens = []
            for token_idx in predicted_indices:
                token = idx_to_token.get(token_idx, '')
                # Stop decoding when a padding token is encountered
                if token == '`':
                    break
                tokens.append(token)
            generated_smiles = "".join(tokens)

            # An empty string is parsed by RDKit into a molecule with zero
            # atoms rather than None, so it must be rejected explicitly;
            # otherwise "nothing" would be reported and drawn as a valid hit.
            mol = Chem.MolFromSmiles(generated_smiles) if generated_smiles else None
            if mol is not None and mol.GetNumAtoms() > 0:
                print(f"Generated molecule (valid SMILES): {generated_smiles}")
                img = MolToImage(mol)
                display(img)
            else:
                print(f"Generated molecule (invalid SMILES): {generated_smiles}")

            generated_molecules.append(generated_smiles)

    return generated_molecules

if __name__ == '__main__':
    # End-to-end script: load the SMILES/protein pairs, train the conditional
    # VAE with early stopping, reload the best checkpoint and sample molecules.

    # Determine the device to use (GPU if available, otherwise CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Define an example protein sequence to predict for
    target_protein_sequence = 'MGNASNDSQSEDCETRQWLPPGESPAISSVMFSAGVLGNLIALALLARRWRGDVGCSAGRRSSLSLFHVLVTELVFTDLLGTCLISPVVLASYARNQTLVALAPESRACTYFAFAMTFFSLATMLMLFAMALERYLSIGHPYFYQRRVSRSGGLAVLPVIYAVSLLFCSLPLLDYGQYVQYCPGTWCFIRHGRTAYLQLYATLLLLLIVSVLACNFSVILNLIRMHRRSRRSRCGPSLGSGRGGPGARRRGERVSMAEETDHLILLAIMTITFAVCSLPFTIFAYMNETSSRKEKWDLQALRFLSINSIIDPWVFAILRPPVLRLMRSVLCCRISLRTQDATQTSCSTQSDASKQADL'

    # --- Training Process ---
    try:
        df = pd.read_csv('final_output_15_2_25.csv')
    except FileNotFoundError:
        print("Error: The data file 'final_output_15_2_25.csv' was not found. Please ensure it is uploaded.")
        # SystemExit works without relying on the interactive `exit` helper
        raise SystemExit(1) from None

    # Rows with a missing SMILES or sequence are read as float NaN, which the
    # tokenizer and the protein embedding cannot iterate; drop them up front
    # instead of crashing in the middle of an epoch.
    df = df.dropna(subset=['SMILES', 'TARGET_SEQUENCE'])
    smiles_list = df['SMILES'].tolist()
    protein_list = df['TARGET_SEQUENCE'].tolist()

    # Create dataset and dataloader (70 % train / 30 % validation)
    full_dataset = SmilesProteinDataset(smiles_list, protein_list, SMILES_TOKEN_DICT, SMILES_MAX_LEN, sorted_tokens)
    train_size = int(0.7 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialize the model and optimizer. The embedding width comes from the
    # dataset so the model always matches the data it is trained on.
    protein_embedding_dim = full_dataset.protein_embedding_dim
    model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_val_loss = float('inf')
    patience_counter = 0

    print("Starting training...")
    for epoch in range(MAX_EPOCHS):
        print(f"--- Epoch {epoch+1}/{MAX_EPOCHS} ---")
        train_loss = train_model(model, train_dataloader, optimizer, device, len(SMILES_TOKEN_DICT))
        val_loss = validate_model(model, val_dataloader, device, len(SMILES_TOKEN_DICT))

        print(f"Epoch {epoch+1}/{MAX_EPOCHS}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Check for improvement and save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            print(f"Validation loss improved. Saving model to {MODEL_SAVE_PATH}")
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            patience_counter += 1
            print(f"Validation loss did not improve. Patience: {patience_counter}/{PATIENCE}")

        # Stop once validation has stalled for PATIENCE consecutive epochs
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    print(f"\nTraining complete. Model saved to '{MODEL_SAVE_PATH}'.")

    # --- Generation Process ---
    # We load the best model to ensure we are using the one with the best performance
    best_model = ConditionalVAE(
        smiles_vocab_size=len(SMILES_TOKEN_DICT),
        protein_embedding_dim=protein_embedding_dim,
        hidden_dim=HIDDEN_DIM,
        latent_dim=LATENT_DIM,
        smiles_max_len=SMILES_MAX_LEN,
        num_layers=NUM_LAYERS
    ).to(device)
    # map_location keeps the load working when the checkpoint was written on
    # a different device than the one currently available
    best_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

    # Generate new molecules using the loaded model
    generate_new_molecules(best_model, SMILES_TOKEN_DICT, target_protein_sequence, num_molecules=5, device=device, latent_dim=LATENT_DIM)


Using device: cuda
Starting training...
--- Epoch 1/100 ---


Training: 100%|██████████| 86/86 [01:02<00:00,  1.38it/s]
Validation: 100%|██████████| 37/37 [00:17<00:00,  2.08it/s]


Epoch 1/100, Train Loss: 9646.2322, Validation Loss: 360.7238
Validation loss improved. Saving model to best_model.pt
--- Epoch 2/100 ---


Training:  10%|█         | 9/86 [00:06<00:55,  1.38it/s]


KeyboardInterrupt: 