# Unconditional Molecular Generation

**Definition**: Molecule Generation is to generate diverse, novel molecules that has desirable chemical properties. These properties are measured by oracle functions. A machine learning task first learns the molecular characteristics from a large set of molecules where each is evaluated through the oracles. Then, from the learned distribution, we can obtain novel candidates.

**Impact**: As the entire chemical space is far too large to screen for each target, high through screening can only be restricted to a set of existing molecule libraries. Many novel drug candidates are thus usually omitted. Machine learning that can generate novel molecules obeying some pre-defined optimal properties can circumvent this problem and obtain novel classes of candidates.


## Install Dependencies
To run the project smoothly, you’ll need to install several Python libraries. Below is a list of required dependencies along with their purpose:

- **TensorFlow**: A popular deep learning framework used for building and training neural networks.
- **RDKit**: A collection of cheminformatics and machine learning tools.
- **Scikit-learn**: A library for machine learning in Python, providing simple and efficient tools for data mining and data analysis.
- **FCD-Torch**: A library for fast computation of molecular fingerprints.
- **numpy**: A fundamental package for scientific computing in Python, providing support for arrays, matrices, and a wide range of mathematical functions.
- **pandas**: A library for data manipulation and analysis, providing data structures like DataFrames.

In [9]:
! pip install selfies pympler  -q 

## Import Necessary libraries

In [10]:
import tensorflow as tf
import selfies as sf
import pandas as pd 
import numpy as np 
import random
import keras
import math
import os
import re
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Embedding, MultiHeadAttention, LayerNormalization, Dropout, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

## Define metrics, optimizer, Global Variables

In [11]:
# Model parameters
BATCH_SIZE = 1024
EMBEDDING_DIM = 256
SEQ_LENGTH = 50
NUM_HEADS = 8
DFF = 1024
# Other variables
DATASET_PATH = "/kaggle/input/drug-discovery/data.csv"
LR = 1e-4

metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
optimizer = tf.keras.optimizers.Adam(learning_rate=LR)

I0000 00:00:1756477455.286006      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


## Set seeds for reproducibility

In [12]:
# Set seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Defining Custom Masked Loss Function

In [13]:
@keras.saving.register_keras_serializable()
def masked_loss(y_true, y_pred):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction='none'
    )
    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.not_equal(y_true, 0), tf.float32)
    loss = loss * mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)


## Defining Model Architecture


In [14]:
@keras.saving.register_keras_serializable()
class LongTermMemory(tf.keras.layers.Layer):
    """
    Long-term memory layer that compresses and refines input representations.
    """
    def __init__(self, units, activation=tf.keras.activations.silu, **kwargs):
        super(LongTermMemory, self).__init__(**kwargs)
        """
        A long-term memory layer that compresses and refines input representations.
        Args:
            units (int): The number of output units.
            activation (callable): The activation function to use.
        Returns:
            None
        """
        super(LongTermMemory, self).__init__()
        self.units = units
        self.activation = activation
        self.name = "Long_Term_Memory"

        # Define layers
        self.fc1 = Dense(self.units, activation=self.activation)
        self.fc2 = Dense(self.units * 2, activation=self.activation)
        self.fc3 = Dense(self.units, activation=self.activation)
    
    def call(self, inputs, mask=None):
        x = self.fc1(inputs)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

@keras.saving.register_keras_serializable()
class Memory(tf.keras.layers.Layer):
    """
    Memory module with long-term + persistent memory integration.
    """
    def __init__(self, embedding_dim, sequence_length, activation=tf.keras.activations.silu, **kwargs):
        super(Memory, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.activation = activation
        self.name = "Memory"

    
        # Query transformation
        self.LMWq = Dense(units=self.embedding_dim, activation=self.activation, use_bias=False, name="Query_Transformation_Layer")
        # Long-term memory
        self.LM = LongTermMemory(self.embedding_dim, activation=self.activation)

        # Persistent memory vector (trainable, sequence-independent)
        self.persistent_memory = self.add_weight(
            shape=(1, 1, self.embedding_dim),
            initializer="glorot_uniform",
            trainable=True,
            name="Persistent_Memory"
        )

        # Normalization after concatenation
        self.norm = LayerNormalization(epsilon=1e-6, name="Memory_Normalization_Layer")

    def call(self, inputs, mask=None):
        q = self.LMWq(inputs)
        ltm_out = self.LM(q)

        batch_size = tf.shape(inputs)[0]
        persistent = tf.tile(self.persistent_memory, [batch_size, self.sequence_length, 1])

        concat = tf.concat([inputs, ltm_out, persistent], axis=-1)
        norm = self.norm(concat)

        return norm


@keras.saving.register_keras_serializable()
class PositionwiseFeedforward(tf.keras.layers.Layer):
    """
    Standard FFN (expansion + projection) used in Transformers.
    """
    def __init__(self, embedding_dim, dff, **kwargs):
        super(PositionwiseFeedforward, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.dff = dff
        self.name = "Positionwise_Feedforward"
        self.dense1 = Dense(self.dff, activation='relu')
        self.dense2 = Dense(self.embedding_dim)

    def call(self, x):
        return self.dense2(self.dense1(x))


In [15]:
@keras.saving.register_keras_serializable()
class Titans(tf.keras.layers.Layer):
    """
    Transformer-based memory-augmented architecture with masking support.
    """

    def __init__(self, embedding_dim, sequence_length, num_heads, dff, vocab_size, 
                 rate=0.4, mask_zero=True, **kwargs):     
        """
        Initializes the Titans layer.
        Args:
            embedding_dim (int): Dimensionality of the embedding space.
            sequence_length (int): Length of the input sequences.
            num_heads (int): Number of attention heads.
            dff (int): Dimensionality of the feedforward network.
            vocab_size (int): Total number of words in the vocabulary.
            rate (float): Dropout rate.
            mask_zero (bool): Whether to mask padding tokens.
        Returns:
            None
        """

        super(Titans, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.dff = dff
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.supports_masking = mask_zero
        self.mask_zero = mask_zero
        self.rate = rate
        self.name = "Titans"

    
    def build(self, input_shape):
        #Initializes layers (embedding, memory, attention, FFN, normalization, gating).
        
        # Embedding + positional encoding
        self.embedding_layer = Embedding(
            input_dim=self.vocab_size,   # should be vocab size, not seq length
            output_dim=self.embedding_dim,
            mask_zero=self.mask_zero
        )

        self.position_embedding = Embedding(
            input_dim=self.sequence_length,
            output_dim=self.embedding_dim
        )

        # Memory + Transformer components
        self.memory = Memory(self.embedding_dim, self.sequence_length)
        self.mha = MultiHeadAttention(num_heads=self.num_heads, key_dim=self.embedding_dim)
        self.ffn = PositionwiseFeedforward(self.embedding_dim * 3, self.dff)

        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = Dropout(self.rate)

        # Gating
        self.gate = Dense(units=self.embedding_dim * 3, activation='sigmoid')
        self.modulation_layer = Dense(units=self.embedding_dim * 3)

        # Final linear layer
        self.final_layer = Dense(units=self.vocab_size)

    def create_causal_mask(self,seq_len):
        """
        Creates a causal mask for the given sequence length.
        Args:
            seq_len (int): Length of the sequence.
        Returns:
            tf.Tensor: Causal mask tensor.
        """
        return tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

    def combine_masks(self,pad_mask, seq_len):
        """
        Combines padding and causal masks.
        Args:
            pad_mask (tf.Tensor): Padding mask tensor.
            seq_len (int): Length of the sequence.
        Returns:
            tf.Tensor: Combined mask tensor.
        """
        causal_mask = self.create_causal_mask(seq_len)  # (seq_len, seq_len)
        causal_mask = causal_mask[tf.newaxis, tf.newaxis, :, :]  # (1,1,seq,seq)
    
        # pad mask → (batch, 1, 1, seq)
        pad_mask = pad_mask[:, tf.newaxis, tf.newaxis, :]
    
        # Combine (broadcast AND)
        return tf.cast(tf.logical_and(tf.cast(pad_mask, tf.bool),
                                  tf.cast(causal_mask, tf.bool)), tf.float32)

    def call(self, inputs, mask=None, training=False):

        # Embedding
        x = self.embedding_layer(inputs)
    
        # Padding mask
        if mask is None:
            mask = self.embedding_layer.compute_mask(inputs)   # (batch, seq_len)


        # Attention mask
        attn_mask = self.combine_masks(mask, self.sequence_length)  # (batch,1,seq,seq)
        pad_mask = tf.cast(mask[:, :, tf.newaxis], x.dtype)  # (batch, seq_len, 1) for element-wise masking
    
        # Positional encoding
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)
        pos_emb = self.position_embedding(positions)
        x = tf.add(x, pos_emb)

        # Memory augmentation
        memory_output = self.memory(x, mask=mask)
        memory_output *= pad_mask   # ensure padding stays zero

        # Multi-head attention
        attn_output = self.mha(
        memory_output,   # query
        memory_output,   # value
        memory_output,   # key
        attention_mask=attn_mask,
        training=training
        )

        attn_output *= pad_mask   # re-mask after MHA

        # Feedforward
        ffn_output = self.ffn(attn_output)
        ffn_output = self.layernorm(ffn_output)
        ffn_output = self.dropout(ffn_output, training=training)
        ffn_output *= pad_mask   # re-mask after FFN + norm

        # Skip connection
        skip = tf.add(memory_output, ffn_output)
        skip *= pad_mask   # ensure skip preserves masking

        # Gating
        linear_gating = self.gate(skip)
        modulated_output = self.modulation_layer(linear_gating)
        output = tf.multiply(linear_gating, modulated_output)
        output *= pad_mask   # final mask application

        # Final projection

        return self.final_layer(output)   # logits over vocab
    
        



In [16]:
# Define a custom model using Titans Transformer-based architecture
@keras.saving.register_keras_serializable()
class Model(tf.keras.Model):
    def __init__(self, embedding_dim, sequence_length, num_heads, dff, vocab_size, **kwargs):
        """
        Custom Transformer-based model using Titans library.
        Args:
            embedding_dim (int): Size of the word embedding.
            sequence_length (int): Maximum length of input sequences.
            num_heads (int): Number of attention heads.
            dff (int): Dimension of the feed-forward network.
            vocab_size (int): Vocabulary size.
        """
        super(Model, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.sequence_length = sequence_length
        self.num_heads = num_heads
        self.dff = dff
        self.vocab_size = vocab_size
        self.name = "Model"
        # Builds the model layers, including the Titans module.
        self.titans = Titans(embedding_dim=self.embedding_dim,
                             sequence_length=self.sequence_length,
                             num_heads=self.num_heads,
                             dff=self.dff,
                             vocab_size=self.vocab_size)


    def call(self, inputs):
        """
        Defines forward pass of the model.
        """
        x = self.titans(inputs, mask=None)
        return x
    
    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "sequence_length": self.sequence_length,
            "num_heads": self.num_heads,
            "dff": self.dff,
            "vocab_size": self.vocab_size
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)



## Reading the dataset

In [17]:
df = pd.read_csv(DATASET_PATH)

In [18]:
sample_df = df.sample(50000)

In [19]:
# Split into Train and Val
train_df, val_df = train_test_split(sample_df, test_size=0.1)

## Preprocess Input Strings

In [20]:
def tokenize_selfies(smiles: str) -> list:
    """Tokenizes a SMILES string into SELFIES tokens.
    Args:
        smiles (str): A SMILES string.
    Returns:
        list: A list of SELFIES tokens.
    """

    try:
        selfies = sf.encoder(smiles)
    except Exception:
        selfies = ""   # skip invalid ones
    return list(sf.split_selfies(selfies))


def build_vocab_from_dataframe(df: pd.DataFrame) -> list:
    """
    Builds a vocabulary of SELFIES tokens directly from a DataFrame of SMILES.
    Adds <start>, <end>, and <UNK> special tokens.
    Args:
        df (pd.DataFrame): A DataFrame containing SMILES strings.
    Returns:
        list: A list of unique SELFIES tokens.
    """
    vocab = {"<start>", "<end>"}  # start with special tokens

    for smiles in df['smiles']:
        tokens = tokenize_selfies(smiles)
        vocab.update(tokens)

    return sorted(list(vocab))


def tokenizer_initialize_from_dataframe(df: pd.DataFrame) -> tuple:
    """
    Initializes a Keras tokenizer using vocab built from DataFrame.
    Args:
        df (pd.DataFrame): A DataFrame containing SMILES strings.
    Returns:
        tuple: A tuple containing the tokenizer, vocabulary size, and all tokens.
    """
    all_tokens = build_vocab_from_dataframe(df)

    tokenizer = Tokenizer(oov_token="<UNK>", filters='', lower=False)
    tokenizer.fit_on_texts(all_tokens)

    vocab_size = len(tokenizer.word_index) + 1 # +1 for padding (index 0)
    return tokenizer, vocab_size, all_tokens


def sequence_generator(df: pd.DataFrame, tokenizer, max_seq_length: int, seq_padding: int = 1) -> tuple:
    """Generates sequences of token IDs from SELFIES with post-padding.
    Args:
        df (pd.DataFrame): A DataFrame containing SMILES strings.
        tokenizer: A Keras tokenizer fitted on SELFIES tokens.
        max_seq_length (int): The maximum sequence length.
        seq_padding (int): The amount of padding to apply.
    Yields:
        tuple: A tuple containing the input and target sequences.
    """
    for smiles in df['smiles']:
        tokens = ["<start>"]
        selfies = tokenize_selfies(smiles)
        tokens.extend(selfies)
        tokens.append("<end>")

        token_ids = tokenizer.texts_to_sequences([tokens])[0]

        target_len = max_seq_length + seq_padding

        # Post-pad sequences
        if len(token_ids) < target_len:
            token_ids = token_ids + [0] * (target_len - len(token_ids))
        elif len(token_ids) > target_len:
            token_ids = token_ids[-target_len:]  # truncate from the left if too long

        x = token_ids[:-1]
        y = token_ids[1:]
        yield x, y


def create_selfies_dataset(sample_df: pd.DataFrame, max_seq_length: int, batch_size: int = 256, buffer_size: int = 10000, seq_padding: int = 1) -> tuple:
    """
    Creates a TensorFlow dataset from SMILES strings using SELFIES encoding.
    Args:
        sample_df (pd.DataFrame): A DataFrame containing SMILES strings.
        max_seq_length (int): The maximum sequence length.
        batch_size (int): The batch size for the dataset.
        buffer_size (int): The buffer size for shuffling.
        seq_padding (int): The amount of padding to apply.
    Returns:
        tuple: A tuple containing the dataset, tokenizer, vocabulary size, maximum sequence length, and all tokens.
    """
    tokenizer, vocab_size, all_tokens = tokenizer_initialize_from_dataframe(sample_df)

    output_signature = (
        tf.TensorSpec(shape=(max_seq_length,), dtype=tf.int32),  # Input sequence
        tf.TensorSpec(shape=(max_seq_length,), dtype=tf.int32)   # Target sequence
    )

    dataset = tf.data.Dataset.from_generator(
        lambda: sequence_generator(sample_df, tokenizer, max_seq_length),
        output_signature=output_signature
    )

    dataset = dataset.shuffle(buffer_size).repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset, tokenizer, vocab_size, max_seq_length, all_tokens


def dataset_gen(dataframe: pd.DataFrame, tokenizer, max_seq_length: int, vocab_size: int, batch_size: int = 128, buffer_size: int = 10000, seq_padding: int = 1) -> tf.data.Dataset:
    """Generates a TensorFlow dataset from a DataFrame using a given tokenizer.
    Args:
        dataframe (pd.DataFrame): A DataFrame containing SMILES strings.
        tokenizer: A Keras tokenizer fitted on SELFIES tokens.
        max_seq_length (int): The maximum sequence length.
        vocab_size (int): The vocabulary size.
        batch_size (int): The batch size for the dataset.
        buffer_size (int): The buffer size for shuffling.
        seq_padding (int): The amount of padding to apply.
    Returns:
        tf.data.Dataset: A TensorFlow dataset.
    """
    output_signature = (
        tf.TensorSpec(shape=(max_seq_length,), dtype=tf.int32),
        tf.TensorSpec(shape=(max_seq_length,), dtype=tf.int32)
    )

    dataset = tf.data.Dataset.from_generator(
        lambda: sequence_generator(dataframe, tokenizer, max_seq_length),
        output_signature=output_signature
    )

    dataset = dataset.shuffle(buffer_size).repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset


# Create the training dataset
dataset, tokenizer, vocab_size, max_seq_length, all_tokens = create_selfies_dataset(
    train_df, max_seq_length=SEQ_LENGTH , batch_size=BATCH_SIZE
)

# Create the validation dataset
val_dataset = dataset_gen(val_df, tokenizer, max_seq_length, vocab_size, batch_size=BATCH_SIZE)


In [21]:
# Get token IDs for special tokens
start_token_id = tokenizer.word_index["<start>"]
end_token_id = tokenizer.word_index["<end>"]
print(f"<start>: {start_token_id}, <end>: {end_token_id}")

<start>: 4, <end>: 3


## Initialize Model

In [22]:
model = Model(embedding_dim=EMBEDDING_DIM, sequence_length=SEQ_LENGTH, num_heads=NUM_HEADS, dff=DFF, vocab_size=vocab_size)



In [23]:
# Build by calling with input shape
inputs = tf.keras.Input(shape=(SEQ_LENGTH,))
outputs = model(inputs)

drug_discovery_model = tf.keras.Model(inputs, outputs, name="DrugDiscoveryModel")
drug_discovery_model.compile(loss=masked_loss, optimizer=optimizer, metrics=metrics)
drug_discovery_model.summary()

## Training

In [24]:


STEPS_PER_EPOCH = math.ceil(len(train_df) / BATCH_SIZE) # returns int 
VAL_STEPS = math.ceil(len(val_df) / BATCH_SIZE)

checkpoint_path = "best_model.weights.h5"

callbacks = [
    ModelCheckpoint(
        filepath=checkpoint_path,
        monitor="sparse_categorical_accuracy",
        mode="max",
        save_best_only=True,
        save_weights_only=True,
        verbose=1
    ),
    EarlyStopping(
        monitor="sparse_categorical_accuracy",
        mode="max",
        patience=10,
        verbose=1,
        restore_best_weights=True
    ),
    # Define LR scheduler
    ReduceLROnPlateau(
        monitor="val_loss",     # quantity to be monitored
        factor=0.5,             # new_lr = lr * factor
        patience=3,             # wait for 3 epochs before reducing LR
        min_lr=1e-6,            # lower bound on LR
        verbose=1               # print LR updates
    )

]

history = drug_discovery_model.fit(
    dataset,
    epochs=1000,
    validation_data=val_dataset,
    steps_per_epoch=int(STEPS_PER_EPOCH),
    validation_steps=int(VAL_STEPS),
    callbacks=callbacks
)

drug_discovery_model.load_weights(checkpoint_path)

# Save the model
drug_discovery_model.save("drug_discovery_model.keras")

Epoch 1/1000


I0000 00:00:1756477494.189630     113 service.cc:148] XLA service 0x7aa788006570 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1756477494.192792     113 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
W0000 00:00:1756477495.052366     113 assert_op.cc:38] Ignoring Assert operator compile_loss/masked_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
I0000 00:00:1756477495.391004     113 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1756477511.681804     113 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 765ms/step - loss: 3.0411 - sparse_categorical_accuracy: 0.2813

W0000 00:00:1756477554.994068     112 assert_op.cc:38] Ignoring Assert operator compile_loss/masked_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert



Epoch 1: sparse_categorical_accuracy improved from -inf to 0.29760, saving model to best_model.weights.h5
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 1s/step - loss: 3.0290 - sparse_categorical_accuracy: 0.2816 - val_loss: 2.0677 - val_sparse_categorical_accuracy: 0.3271 - learning_rate: 1.0000e-04
Epoch 2/1000
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 772ms/step - loss: 2.0431 - sparse_categorical_accuracy: 0.3298
Epoch 2: sparse_categorical_accuracy improved from 0.29760 to 0.33621, saving model to best_model.weights.h5
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 1s/step - loss: 2.0421 - sparse_categorical_accuracy: 0.3300 - val_loss: 1.8721 - val_sparse_categorical_accuracy: 0.3549 - learning_rate: 1.0000e-04
Epoch 3/1000
[1m44/44[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 740ms/step - loss: 1.8343 - sparse_categorical_accuracy: 0.3637
Epoch 3: sparse_categorical_accuracy improved from 0.33621 to 0.37424, s

## Inference

In [25]:
def top_p_logits_batch(logits: tf.Tensor, top_p: float = 1.0) -> tf.Tensor:
    """
    Apply top-p (nucleus) filtering to a batch of logits.
    Args:
        logits: [batch, vocab]
        top_p: cumulative probability threshold
    Returns:
        [batch, vocab]: Filtered logits.
    """
    if top_p >= 1.0:
        return logits

    new_logits = []
    for logit in logits:  # loop over batch
        sorted_indices = tf.argsort(logit, direction='DESCENDING')
        sorted_logits = tf.gather(logit, sorted_indices)
        sorted_probs = tf.nn.softmax(sorted_logits)
        cumulative_probs = tf.cumsum(sorted_probs)

        # Mask tokens outside top_p
        mask = cumulative_probs > top_p
        mask = tf.concat([[False], mask[:-1]], axis=0)  # keep first above top_p
        # Set masked logits to -inf
        sorted_logits = tf.where(mask, tf.fill(tf.shape(sorted_logits), float('-inf')), sorted_logits)
        # Scatter back to original order
        new_logit = tf.scatter_nd(
            indices=tf.expand_dims(sorted_indices, axis=-1),
            updates=sorted_logits,
            shape=tf.shape(logit, out_type=tf.int32)
        )
        new_logits.append(new_logit)

    return tf.stack(new_logits, axis=0)


def sample_from_logits(logits: tf.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0) -> tf.Tensor:
    """
    Sample token IDs from logits with temperature scaling, top-k, and top-p filtering.
    Args:
        logits: [batch, vocab]
        temperature: scaling factor
        top_k: number of top tokens to consider
        top_p: cumulative probability threshold
    Returns:
        [batch]: sampled token IDs
    """


    # Scale by temperature
    logits = logits / temperature

    # Top-k filtering
    if top_k > 0:
        values, _ = tf.math.top_k(logits, k=top_k)
        min_values = values[:, -1, tf.newaxis]
        logits = tf.where(logits < min_values, tf.constant(-np.inf, dtype=logits.dtype), logits)
                          

    # Top-p
    if top_p < 1.0:
       logits = top_p_logits_batch(logits, top_p)
    # Debug: print logits statistics
    #print("sample_from_logits:")
    #print("  logits min/max:", tf.reduce_min(logits).numpy(), tf.reduce_max(logits).numpy())
    #print("  logits mean:", tf.reduce_mean(logits).numpy())

    # Convert to probabilities safely and sample
    sampled_ids = tf.random.categorical(logits, num_samples=1)[:, 0]
    return sampled_ids


def generate_drug_batch(
    seed_texts,
    model,
    tokenizer,
    max_length,
    next_words=30,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    end_token_id=None,
    min_length_before_eos=20,
):
    """
    Generate sequences from a model with autoregressive decoding.

    Args:
        seed_texts (list[str]): Starting strings for each sequence in the batch.
        model (tf.keras.Model): Trained model for generation.
        tokenizer: Tokenizer with `texts_to_sequences` & `index_word`.
        max_length (int): Maximum sequence length.
        next_words (int): Maximum number of tokens to generate.
        temperature (float): Softmax temperature scaling.
        top_k (int): Top-k sampling cutoff.
        top_p (float): Nucleus (top-p) sampling cutoff.
        end_token_id (int): ID of EOS token.
        min_length_before_eos (int): Minimum tokens before EOS allowed.

    Returns:
        list[str]: Generated sequences.
    """
    assert next_words <= max_length, "next_words must be <= max_length"

    batch_size = len(seed_texts)

    # Convert seeds to padded token sequences
    token_lists = tokenizer.texts_to_sequences(seed_texts)
    token_lists = pad_sequences(token_lists, maxlen=max_length, padding="post")

    finished = [False] * batch_size
    step = 0

    while step < next_words:
        # Forward pass -> logits for next token
        predicted_logits = model.predict(token_lists, verbose=0)  # [B, T, V]
        logits = predicted_logits[:, -1, :]  # take last step logits [B, V]

        # Mask out banned tokens
        banned_tokens = [
            tokenizer.word_index.get("<UNK>"),
            tokenizer.word_index.get("<start>"),
        ]
        if end_token_id is not None and step < min_length_before_eos:
            banned_tokens.append(end_token_id)

        banned_tokens = [t for t in banned_tokens if t is not None]

        if banned_tokens:
            vocab_size = logits.shape[-1]
            mask = tf.zeros(vocab_size, dtype=logits.dtype)
            updates = tf.constant([-float("inf")] * len(banned_tokens), dtype=logits.dtype)
            mask = tf.tensor_scatter_nd_update(
                mask,
                indices=[[tid] for tid in banned_tokens],
                updates=updates,
            )
            logits = logits + mask  # broadcast to [B, V]

      
        sampled_ids = sample_from_logits(logits, temperature, top_k, top_p)

        for i, token_id in enumerate(sampled_ids.numpy()):
            if finished[i]:
                continue

            if end_token_id is not None and token_id == end_token_id:
                finished[i] = True
                continue

            word = tokenizer.index_word.get(token_id, None)
            if word and word not in ("<UNK>", "<start>"):
                seed_texts[i] += " " + word

        # Roll sequence window forward
        token_lists = tf.concat(
            [token_lists[:, 1:], tf.expand_dims(sampled_ids, axis=-1)], axis=1
        )

        if all(finished):
            break

        step += 1

    return seed_texts


In [26]:
# Define prompts for generation
prompts = ["<start>" for i in range(1000)]

In [27]:
# Generate drug candidates
generated = generate_drug_batch(
    prompts,
    model,
    tokenizer,
    max_length=50,
    next_words=50,
    temperature=1,
    top_k=30,
    top_p=0.95,
    end_token_id=end_token_id
)

## Evaluation

In [33]:
train_dataset = []
for smile in train_df['smiles'].sample(1000):
    train_dataset.append(smile)
print(len(train_dataset))

1000


In [34]:

gen_smiles =[]
for g in generated:
    g = g.replace(' ','')
    g = g.replace('<start>','')
    try:
        sm = sf.decoder(g)
        gen_smiles.append(sm)
    except Exception as e:
        print(g,e)



In [35]:
! pip install rdkit-pypi scikit-learn fcd-torch -q

In [38]:
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors, BRICS
from rdkit.Chem.Scaffolds import MurckoScaffold



class MoleculeEvaluator:
    def __init__(self, gen_smiles, ref_smiles, radius=2, n_bits=2048):
        self.gen_smiles = list(gen_smiles)  
        self.ref_smiles = list(ref_smiles)
        self.radius = radius
        self.n_bits = n_bits
        
        self.gen_mols = [Chem.MolFromSmiles(s) for s in self.gen_smiles if Chem.MolFromSmiles(s)]
        self.ref_mols = [Chem.MolFromSmiles(s) for s in self.ref_smiles if Chem.MolFromSmiles(s)]
        
        self.gen_fps = [AllChem.GetMorganFingerprintAsBitVect(m, self.radius, nBits=self.n_bits) for m in self.gen_mols]
        self.ref_fps = [AllChem.GetMorganFingerprintAsBitVect(m, self.radius, nBits=self.n_bits) for m in self.ref_mols]

    # Core Metrics

    def internal_diversity(self):
        """Mean pairwise Tanimoto dissimilarity within generated molecules"""
        n = len(self.gen_fps)
        if n < 2:
            return 0.0
        dists = []
        for i in range(n):
            for j in range(i+1, n):
                sim = DataStructs.TanimotoSimilarity(self.gen_fps[i], self.gen_fps[j])
                dists.append(1 - sim)
        return np.mean(dists)

    def nearest_neighbor_similarity(self):
        """Similarity of each generated molecule to nearest neighbor in reference set"""
        sims = []
        for fp in self.gen_fps:
            sim = max(DataStructs.BulkTanimotoSimilarity(fp, self.ref_fps))
            sims.append(sim)
        return np.mean(sims) if sims else 0.0

    def scaffold_similarity(self):
        """Scaffold overlap between generated and reference sets"""
        def get_scaffolds(mols):
            scaffolds = set()
            for m in mols:
                scaff = MurckoScaffold.MurckoScaffoldSmiles(mol=m)
                scaffolds.add(scaff)
            return scaffolds

        gen_scaff = get_scaffolds(self.gen_mols)
        ref_scaff = get_scaffolds(self.ref_mols)
        inter = len(gen_scaff.intersection(ref_scaff))
        return inter / len(gen_scaff) if gen_scaff else 0.0

    def fragment_similarity(self):
        """Fragment overlap (BRICS decomposition)"""
        def get_fragments(mols):
            frags = []
            for m in mols:
                parts = list(BRICS.BRICSDecompose(m))
                frags.extend(parts)
            return set(frags)

        gen_frags = get_fragments(self.gen_mols)
        ref_frags = get_fragments(self.ref_mols)
        inter = len(gen_frags.intersection(ref_frags))
        return inter / len(gen_frags) if gen_frags else 0.0

    def novelty(self):
        """Fraction of generated molecules not in reference set"""
        ref_set = set(self.ref_smiles)
        novel = [s for s in self.gen_smiles if s not in ref_set]
        return len(novel) / len(self.gen_smiles) if self.gen_smiles else 0.0

    def validity(self):
        """Fraction of valid molecules"""
        return len(self.gen_mols) / len(self.gen_smiles) if self.gen_smiles else 0.0

    def unique_at_k(self, k=None):
        """Fraction of unique molecules (optionally top-k)"""
        smiles = self.gen_smiles[:k] if k else self.gen_smiles
        return len(set(smiles)) / len(smiles) if smiles else 0.0

    def filters(self):
        """Simple Lipinski-like filter"""
        passed = 0
        for m in self.gen_mols:
            mw = Descriptors.MolWt(m)
            logp = Descriptors.MolLogP(m)
            hbd = Descriptors.NumHDonors(m)
            hba = Descriptors.NumHAcceptors(m)
            if (mw < 500 and logp < 5 and hbd <= 5 and hba <= 10):
                passed += 1
        return passed / len(self.gen_mols) if self.gen_mols else 0.0

    # Run all 
    def evaluate_all(self, unique_k=None):
        return {
            "InternalDiversity": self.internal_diversity(),
            "NearestNeighborSimilarity": self.nearest_neighbor_similarity(),
            "ScaffoldSimilarity": self.scaffold_similarity(),
            "FragmentSimilarity": self.fragment_similarity(),
            "Novelty": self.novelty(),
            "Validity": self.validity(),
            "Unique@k": self.unique_at_k(unique_k),
            "FiltersPass": self.filters()
        }



mol_eval = MoleculeEvaluator(gen_smiles=gen_smiles,ref_smiles=train_dataset)

[16:14:33] Explicit valence for atom # 5 Cl, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 6 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 11 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 5 Cl, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 7 Cl, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 1 K, 3, is greater than permitted
[16:14:33] Explicit valence for atom # 12 Cl, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 3 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 4 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 3 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 4 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 1 K, 2, is greater than permitted
[16:14:33] Explicit valence for atom # 3 Na, 2, is greater than permitted
[16:14:33] Explicit valence for atom #

In [39]:
mol_eval.evaluate_all()

{'InternalDiversity': 0.8870384378472012,
 'NearestNeighborSimilarity': 0.10365603236436531,
 'ScaffoldSimilarity': 0.022222222222222223,
 'FragmentSimilarity': 0.03024193548387097,
 'Novelty': 1.0,
 'Validity': 0.957,
 'Unique@k': 0.999,
 'FiltersPass': 0.9811912225705329}