In [None]:


!pip install torch rdkit pypi selfies pandas

Collecting rdkit
  Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting pypi
  Downloading pypi-2.1.tar.gz (997 bytes)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting selfies
  Downloading selfies-2.2.0-py3-none-any.whl.metadata (14 kB)
Downloading rdkit-2025.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading selfies-2.2.0-py3-none-any.whl (36 kB)
Building wheels for collected packages: pypi
  Building wheel for pypi (setup.py) ... [?25l[?25hdone
  Created wheel for pypi: filename=pypi-2.1-py3-none-any.whl size=1334 sha256=4e5b0a170d1688b2c11ba23ac8d05b68d85139acc0883cbd94a838996cedabd6
  Stored in directory: /root/.cache/pip/wheels/28/4c/49/00cdce1e7a68a48810e9203391f80f4c7344a5e4ad9d4d6649
Successfully built pypi
Installing collected packages: pypi, selfies, rdkit
Successfully installed pyp

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import pandas as pd
import numpy as np
import selfies as sf
import random
import math
from tqdm.auto import tqdm

from rdkit import Chem
from rdkit.Chem import QED, Crippen, Descriptors

# Seed every RNG the notebook draws from (Python `random`, NumPy, PyTorch) so data
# shuffling, noise inputs and token sampling are reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Cell 3: Download and load the dataset
url = "https://media.githubusercontent.com/media/molecularsets/moses/master/data/dataset_v1.csv"
df = pd.read_csv(url)

# Train only on the training split so held-out benchmark molecules stay unseen.
# NOTE(review): assumes the CSV has a SPLIT column whose training rows are labelled 'train' — confirm.
USE_TRAIN_SPLIT_ONLY = True  # set False to train on every row, as earlier versions did
if USE_TRAIN_SPLIT_ONLY and 'SPLIT' in df.columns:
    df_train = df[df['SPLIT'] == 'train']
else:
    df_train = df
smiles_list = df_train['SMILES'].tolist()


def _encode_selfies_or_none(smi):
    """Return the SELFIES encoding of `smi`, or None when selfies cannot encode it."""
    try:
        return sf.encoder(smi)
    except sf.EncoderError:
        return None


# Convert SMILES to SELFIES, skipping (and reporting) molecules the encoder rejects
encoded = [_encode_selfies_or_none(smi) for smi in tqdm(smiles_list, desc="Encoding SELFIES")]
n_failed = sum(s is None for s in encoded)
if n_failed:
    print(f"Skipped {n_failed} SMILES that could not be encoded as SELFIES")
selfies_list = [s for s in encoded if s is not None]
assert selfies_list, "no molecules left to train on"

# Cell 4: Build Vocabulary
def build_vocab(selfies_list):
    """Collect every SELFIES symbol in the corpus and assign each an integer id.

    Ids 0, 1 and 2 are reserved for '<pad>', '<sos>' and '<eos>'; the corpus
    symbols follow in sorted order.

    Returns:
        (vocab, token_to_int, int_to_token): the ordered token list and the two
        lookup dicts between tokens and ids.
    """
    symbols = {tok for selfies_str in selfies_list for tok in sf.split_selfies(selfies_str)}
    vocab = ['<pad>', '<sos>', '<eos>', *sorted(symbols)]

    token_to_int = {}
    int_to_token = {}
    for idx, tok in enumerate(vocab):
        token_to_int[tok] = idx
        int_to_token[idx] = tok

    return vocab, token_to_int, int_to_token


vocab, token_to_int, int_to_token = build_vocab(selfies_list)
vocab_size = len(vocab)  # corpus symbols plus the three special tokens
print(f"Vocabulary size: {vocab_size}")

# Cell 5: Create a PyTorch Dataset
class SELFIESDataset(Dataset):
    """PyTorch Dataset yielding fixed-length token-id tensors for SELFIES strings.

    Each item is ``<sos> symbols... <eos>`` followed by ``<pad>`` up to ``max_len``,
    returned as a 1-D LongTensor of length ``max_len``. Sequences that do not fit are
    truncated but keep their closing ``<eos>``, so every training molecule still
    shows where it ends. Symbols missing from ``token_to_int`` map to the pad id.
    """

    def __init__(self, selfies_list, token_to_int, max_len=128):
        if max_len < 2:
            raise ValueError("max_len must be at least 2 to hold <sos> and <eos>")
        self.selfies_list = selfies_list
        self.token_to_int = token_to_int
        self.max_len = max_len
        self.pad_token_id = token_to_int['<pad>']
        self.sos_token_id = token_to_int['<sos>']
        self.eos_token_id = token_to_int['<eos>']

    def __len__(self):
        return len(self.selfies_list)

    def _encode_tokens(self, selfies_tokens):
        """Map a list of SELFIES symbols to a padded id tensor of length max_len."""
        body = [self.token_to_int.get(tok, self.pad_token_id) for tok in selfies_tokens]
        # Reserve room for <sos> and <eos>: truncation drops body symbols, never <eos>.
        body = body[:self.max_len - 2]
        token_ids = [self.sos_token_id] + body + [self.eos_token_id]
        token_ids += [self.pad_token_id] * (self.max_len - len(token_ids))
        return torch.tensor(token_ids, dtype=torch.long)

    def __getitem__(self, idx):
        return self._encode_tokens(list(sf.split_selfies(self.selfies_list[idx])))

# Data-pipeline hyperparameters
MAX_LEN = 128    # fixed token length of every sample, including <sos>/<eos> and padding
BATCH_SIZE = 16  # molecules per training batch (also the RL sampling batch size)

dataset = SELFIESDataset(selfies_list, token_to_int, max_len=MAX_LEN)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)

Encoding SELFIES:   0%|          | 0/1936962 [00:00<?, ?it/s]

Vocabulary size: 29


In [None]:
# Cell 6: Transformer Generator
class TransformerGenerator(nn.Module):
    """Encoder-decoder Transformer mapping (src, tgt) token ids to next-token logits.

    Token embeddings are scaled by sqrt(d_model) and summed with a learned positional
    table (initialised to zeros). The decoder uses a causal mask, so output position t
    depends only on tgt[:, :t + 1].

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward:
            forwarded to ``nn.Transformer``.
        max_seq_length: longest src/tgt length the positional table supports.
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=512, max_seq_length=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward, batch_first=True)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model
        self.max_seq_length = max_seq_length

    def _embed(self, tokens):
        """Scaled token embedding plus positional embedding for (batch, seq) ids."""
        seq_len = tokens.size(1)
        if seq_len > self.max_seq_length:
            # Fail clearly instead of with an opaque broadcasting error below.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length={self.max_seq_length}")
        return (self.embedding(tokens) * math.sqrt(self.d_model)
                + self.positional_encoding[:, :seq_len, :])

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, vocab_size) for (batch, seq) id tensors."""
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        # Build the causal mask on the input's device rather than a module-global one.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
        return self.fc_out(output)

# CNN critic that scores already-embedded token sequences
class CNNDiscriminator(nn.Module):
    """Text-CNN critic over pre-embedded sequences.

    The module owns an ``nn.Embedding`` but ``forward`` does not call it: callers embed
    token ids themselves (e.g. ``discriminator.embedding(ids)``) and pass the result in.

    Args:
        vocab_size: rows of the owned embedding table.
        d_model: embedding width; every conv kernel spans the full width.
        max_seq_length: accepted for interface compatibility; unused.
        num_filters: feature maps per kernel size (default 100).
        kernel_sizes: n-gram widths of the convolutions (default 2, 3, 4, 5).
    """

    def __init__(self, vocab_size, d_model=256, max_seq_length=128,
                 num_filters=100, kernel_sizes=(2, 3, 4, 5)):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, d_model)) for k in kernel_sizes
        ])
        # One max-pooled feature per filter per kernel size feeds the scoring layer.
        self.fc = nn.Linear(num_filters * len(kernel_sizes), 1)

    def forward(self, embedded_x):
        """Score embedded sequences.

        embedded_x: (batch, seq_len, d_model); seq_len must be >= max(kernel_sizes).
        Returns a (batch, 1) unbounded score (no sigmoid is applied).
        """
        x = embedded_x.unsqueeze(1)  # (batch, 1, seq_len, d_model)
        pooled = []
        for conv in self.convs:
            feat = F.relu(conv(x)).squeeze(3)                        # (batch, num_filters, seq_len - k + 1)
            pooled.append(F.max_pool1d(feat, feat.size(2)).squeeze(2))  # max over time
        return self.fc(torch.cat(pooled, dim=1))

In [None]:
# GAN training hyperparameters
D_MODEL = 128     # token-embedding / transformer width (must be divisible by the 8 attention heads)
LR = 1e-4         # Adam learning rate; the RL stage later uses LR / 10
LAMBDA_GP = 10.0  # weight of the WGAN-GP gradient-penalty term
N_EPOCHS = 50     # full passes over the dataloader during GAN training

In [None]:
# Cell 8: WGAN-GP gradient penalty on embedded samples
def compute_gradient_penalty(discriminator, real_samples_embedded, fake_samples_embedded):
    """Return the WGAN-GP penalty mean((||grad_x D(x)||_2 - 1)^2).

    x lies on the segment between paired samples: x = a * real + (1 - a) * fake,
    with a ~ U[0, 1) drawn once per sample.

    Args:
        discriminator: callable mapping a (batch, ...) tensor to (batch, 1) scores.
        real_samples_embedded, fake_samples_embedded: same-shape float tensors whose
            first dimension is the batch; any trailing shape is accepted.

    Returns:
        Scalar tensor built with ``create_graph=True`` so it can be back-propagated
        into the discriminator's parameters.
    """
    batch_size = real_samples_embedded.size(0)
    # Uniform mixing weights: N(0, 1) weights (the previous torch.randn) mostly
    # extrapolate outside the real-fake segment the penalty is defined on.
    alpha_shape = (batch_size,) + (1,) * (real_samples_embedded.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples_embedded.device,
                       dtype=real_samples_embedded.dtype)

    interpolates = (alpha * real_samples_embedded
                    + (1 - alpha) * fake_samples_embedded).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# Cell 9: WGAN-GP training loop
# (Re-initialize models and optimizers before running this cell)
generator = TransformerGenerator(vocab_size, d_model=D_MODEL, max_seq_length=MAX_LEN).to(device)
discriminator = CNNDiscriminator(vocab_size, d_model=D_MODEL, max_seq_length=MAX_LEN).to(device)
optimizer_g = optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.9))
optimizer_d = optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.9))
pad_token_id = token_to_int['<pad>']
criterion_g = nn.CrossEntropyLoss(ignore_index=pad_token_id)

N_CRITIC = 5      # discriminator updates per generator update
GUMBEL_TAU = 1.0  # temperature of the straight-through Gumbel-softmax


def embed_generated(logits, target, embedding, pad_id, tau=GUMBEL_TAU):
    """Embed generator logits for the critic while keeping them differentiable.

    ``argmax`` has no gradient, so an adversarial loss computed on argmax tokens can
    never update the generator. Straight-through Gumbel-softmax (``hard=True``) gives
    one-hot vectors in the forward pass, the same kind of input real token ids give,
    and soft-sample gradients in the backward pass. Positions where ``target`` is
    ``pad_id`` get the pad embedding, so real and fake share one padding layout.

    logits: (batch, seq, vocab); target: (batch, seq) ids; returns (batch, seq, d_model).
    """
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True)
    embedded = one_hot @ embedding.weight
    is_pad = (target == pad_id).unsqueeze(-1)
    return torch.where(is_pad, embedding.weight[pad_id], embedded)


def sample_noise_source(like):
    """Random token ids shaped like `like`, used as the generator's encoder input (latent)."""
    return torch.randint(0, vocab_size, like.shape, device=like.device)


print("Starting Corrected WGAN-GP Training...")
generator.train()
discriminator.train()
d_loss = g_loss = None
for epoch in range(N_EPOCHS):
    for i, real_seq in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{N_EPOCHS}")):
        real_seq = real_seq.to(device)
        decoder_input = real_seq[:, :-1]  # teacher-forcing input, starts with <sos>
        real_target = real_seq[:, 1:]     # tokens the generator must predict (no <sos>)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_d.zero_grad()
        with torch.no_grad():  # the generator is not updated in this phase
            fake_logits = generator(sample_noise_source(real_seq), decoder_input)

        # Compare like with like: both sides cover the MAX_LEN - 1 target positions.
        # (Keeping <sos> only on the real side lets the critic win from position 0 alone.)
        real_embedded = discriminator.embedding(real_target)
        fake_embedded = embed_generated(fake_logits, real_target, discriminator.embedding, pad_token_id)

        real_validity = discriminator(real_embedded)
        fake_validity = discriminator(fake_embedded)

        gradient_penalty = compute_gradient_penalty(
            discriminator, real_embedded.detach(), fake_embedded.detach())
        d_loss = -real_validity.mean() + fake_validity.mean() + LAMBDA_GP * gradient_penalty

        d_loss.backward()
        optimizer_d.step()

        # -----------------
        #  Train Generator
        # -----------------
        if i % N_CRITIC == 0:
            optimizer_g.zero_grad()

            fake_logits_g = generator(sample_noise_source(real_seq), decoder_input)
            fake_embedded_g = embed_generated(fake_logits_g, real_target,
                                              discriminator.embedding, pad_token_id)
            adv_loss = -discriminator(fake_embedded_g).mean()

            # Supervised teacher-forcing loss (helps with stability); pads are ignored.
            sup_loss = criterion_g(fake_logits_g.reshape(-1, vocab_size), real_target.reshape(-1))

            g_loss = adv_loss + sup_loss
            # This also fills the discriminator's .grad buffers; optimizer_d.zero_grad()
            # clears them before the next critic step.
            g_loss.backward()
            optimizer_g.step()

    if d_loss is not None and g_loss is not None:
        print(f"[Epoch {epoch+1}/{N_EPOCHS}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

⏳ No WGAN checkpoint found. Starting training from scratch.
Starting WGAN-GP Training...


Epoch 1/50:   0%|          | 0/15625 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Cell 10: Multi-objective reward function
def calculate_reward(selfies_string, weights=(0.5, 0.25, 0.25)):
    """Score a SELFIES string from QED, logP and a molecular-weight penalty.

    reward = w_qed * QED + w_logp * (1 - |logP - 2.5| / 2.5) - w_mw * max(0, (MW - 350) / 100),
    clipped below at 0.0. The last term is a crude size penalty standing in for a
    synthetic-accessibility score; it looks only at molecular weight.

    Args:
        selfies_string: molecule in SELFIES notation.
        weights: (w_qed, w_logp, w_mw); the defaults reproduce the original weighting.

    Returns:
        Non-negative float; 0.0 when the string does not decode to a non-empty molecule
        or a descriptor cannot be computed.
    """
    try:
        smi = sf.decoder(selfies_string)
        mol = Chem.MolFromSmiles(smi)
    except Exception:  # undecodable SELFIES / unparsable SMILES -> no reward
        return 0.0
    # An empty string decodes to an atom-less Mol that would otherwise reach scoring.
    if mol is None or mol.GetNumAtoms() == 0:
        return 0.0

    try:
        qed = QED.qed(mol)
        logp = Crippen.MolLogP(mol)
        mw = Descriptors.MolWt(mol)
    except Exception:  # RDKit descriptor failure on an unusual molecule
        return 0.0

    w_qed, w_logp, w_mw = weights
    logp_score = 1 - abs(logp - 2.5) / 2.5   # 1 at logP = 2.5, falling linearly away from it
    mw_penalty = max(0.0, (mw - 350) / 100)  # penalise molecules heavier than 350
    reward = w_qed * qed + w_logp * logp_score - w_mw * mw_penalty
    return max(0.0, reward)

# Cell 11: RL fine-tuning with REINFORCE (batched sampling)
print("\nStarting RL Fine-tuning (Batched)...")
optimizer_g_rl = optim.Adam(generator.parameters(), lr=LR / 10)

N_RL_STEPS = 100
sos_id = token_to_int['<sos>']
eos_id = token_to_int['<eos>']
for step in range(N_RL_STEPS):
    generator.train()
    optimizer_g_rl.zero_grad()

    # Random-token encoder input, the same kind of source the GAN stage trained on.
    noise_src = torch.randint(0, vocab_size, (BATCH_SIZE, MAX_LEN), device=device)
    input_seq = torch.full((BATCH_SIZE, 1), sos_id, device=device, dtype=torch.long)
    finished = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=device)
    seq_log_prob = torch.zeros(BATCH_SIZE, device=device)  # summed log-prob of each sampled sequence
    generated_tokens = [[] for _ in range(BATCH_SIZE)]

    # Autoregressively sample up to MAX_LEN - 1 tokens after <sos>
    for _ in range(MAX_LEN - 1):
        if bool(finished.all()):  # every sequence has emitted <eos>
            break

        logits = generator(noise_src, input_seq)[:, -1, :]
        log_probs = F.log_softmax(logits, dim=-1)  # numerically safer than log(softmax)
        next_tokens = torch.multinomial(log_probs.exp(), 1)  # (batch, 1)
        step_log_prob = log_probs.gather(1, next_tokens).squeeze(1)

        # Every action of a still-active sequence counts, including the <eos> that ends it.
        seq_log_prob = seq_log_prob + step_log_prob.masked_fill(finished, 0.0)

        active = (~finished).tolist()
        for b, token_id in enumerate(next_tokens.squeeze(1).tolist()):
            if active[b] and token_id != eos_id:
                generated_tokens[b].append(int_to_token[token_id])

        # Rebind instead of mutating in place: masked_fill above used this mask,
        # and autograd may need it unchanged at backward time.
        finished = finished | (next_tokens.squeeze(1) == eos_id)
        input_seq = torch.cat([input_seq, next_tokens], dim=1)

    generated_selfies = [''.join(tokens) for tokens in generated_tokens]
    batch_rewards = [calculate_reward(s) for s in generated_selfies]

    # REINFORCE with a mean-reward baseline: above-average samples are reinforced,
    # below-average ones discouraged (rewards are all >= 0, so no baseline would push up everything).
    rewards = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
    advantages = rewards - rewards.mean()
    policy_loss = -(seq_log_prob * advantages).mean()

    policy_loss.backward()
    torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
    optimizer_g_rl.step()

    if (step + 1) % 10 == 0:
        print(f"[RL Step {step+1}/{N_RL_STEPS}] [Avg Reward: {np.mean(batch_rewards):.4f}]")

In [None]:
# Cell 12: Generation and evaluation
def generate_molecules(generator, num_molecules=5, temperature=None):
    """Generate molecules one at a time, print each with its reward, and return them.

    Each molecule is decoded autoregressively from <sos>, conditioned on a fresh
    random-token source sequence (the kind of encoder input used during GAN and RL
    training), for at most MAX_LEN - 1 tokens or until <eos>. A fresh source per
    molecule keeps greedy decoding from returning the same molecule every time.

    Args:
        generator: module called as ``generator(src, tgt)`` returning (batch, seq, vocab) logits.
        num_molecules: how many molecules to generate.
        temperature: None for greedy (argmax) decoding; a positive float samples
            from softmax(logits / temperature).

    Returns:
        list of (smiles, reward) tuples; smiles is None when the SELFIES could not be decoded.
    """
    generator.eval()
    sos_id = token_to_int['<sos>']
    eos_id = token_to_int['<eos>']
    results = []
    print("\n--- Generating Final Molecules ---")
    with torch.no_grad():
        for i in range(num_molecules):
            noise_src = torch.randint(0, vocab_size, (1, MAX_LEN), device=device)
            input_seq = torch.tensor([[sos_id]], device=device)
            tokens = []
            for _ in range(MAX_LEN - 1):
                logits = generator(noise_src, input_seq)[:, -1, :]
                if temperature is None:
                    next_token = logits.argmax(dim=-1, keepdim=True)  # (1, 1)
                else:
                    next_token = torch.multinomial(F.softmax(logits / temperature, dim=-1), 1)
                token_id = next_token.item()
                if token_id == eos_id:
                    break
                tokens.append(int_to_token[token_id])
                input_seq = torch.cat([input_seq, next_token], dim=1)

            generated_selfies = "".join(tokens)
            try:
                smi = sf.decoder(generated_selfies)
            except Exception:  # e.g. special tokens such as <pad> inside the string
                smi = None
            reward = calculate_reward(generated_selfies)
            print(f"Molecule {i+1}: {smi} | Reward: {reward:.3f}")
            results.append((smi, reward))
    return results


# Generate some molecules from the final model (assigned so the list is not echoed as cell output)
final_molecules = generate_molecules(generator, num_molecules=10)

#STOP HERE


In [None]:
# --- Save the trained generator and its tokenizer ---
import json

# 1. Save the model's learned weights (the state_dict)
torch.save(generator.state_dict(), 'generator_model.pth')

# 2. Save the tokenizer vocabulary plus the constructor arguments needed to rebuild the model.
#    JSON object keys are always strings, so int_to_token is written with str keys;
#    restore with {int(k): v for k, v in data['int_to_token'].items()} when loading.
tokenizer_data = {
    'token_to_int': token_to_int,
    'int_to_token': {str(k): v for k, v in int_to_token.items()},
    'model_config': {
        'vocab_size': vocab_size,
        'd_model': D_MODEL,
        'max_seq_length': MAX_LEN,
    },
}
with open('tokenizer.json', 'w') as f:
    json.dump(tokenizer_data, f)

# 3. Download the files (only possible inside Google Colab)
try:
    from google.colab import files
except ImportError:
    print("Not running in Colab: generator_model.pth and tokenizer.json "
          "were saved to the working directory")
else:
    files.download('generator_model.pth')
    files.download('tokenizer.json')