# Importing Libraries

In [None]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem
from rdkit.Chem.rdmolops import Get={'AllBonds': {'bond_types': [], 'bond_stereo': []}}

# Initialize the TPU (this notebook targets a Kaggle TPU accelerator for faster processing)

In [None]:

def _detect_tpu():
    """Return a TPUClusterResolver, or None when the resolver raises ValueError.

    A ValueError from the resolver is treated as "no TPU attached".
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
        print('Running on TPU ', resolver.cluster_spec().as_dict()['worker'])
    except ValueError:
        print('Not running on TPU, defaulting to GPU/CPU.')
        return None
    return resolver


tpu = _detect_tpu()

if not tpu:
    # No TPU: fall back to TensorFlow's default strategy (single GPU or CPU).
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster and initialize it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

# Data Loading and Preprocessing

In [None]:
SMILES_FILE_PATH = 'pubchem_smiles_for_pretraining.txt'


def load_smiles_data(file_path, num_samples=None):
    """Read SMILES strings from a text file containing one SMILES per line.

    Args:
        file_path: Path of the UTF-8 text file to read.
        num_samples: Maximum number of SMILES to return. ``None`` reads the
            whole file; ``0`` returns an empty list.

    Returns:
        List of whitespace-stripped, non-empty SMILES strings in file order.
        Blank lines (e.g. a trailing newline at EOF) are skipped and do not
        count towards ``num_samples``.
    """
    smiles_list = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Compare against None explicitly so num_samples=0 means "none", not "all".
            if num_samples is not None and len(smiles_list) >= num_samples:
                break
            smiles = line.strip()
            if smiles:  # an empty SMILES would only yield an atom-less molecule later
                smiles_list.append(smiles)
    print(f"Loaded {len(smiles_list)} SMILES strings.")
    return smiles_list


# Cap the number of SMILES read for quick experiments (50,000 here). Raise the
# cap, or pass num_samples=None to load the whole file, for a full run.
all_smiles = load_smiles_data(SMILES_FILE_PATH, num_samples=50000) 

# --- SMILES Tokenization ---
# A simple character-level tokenizer for demonstration.
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    """Build a character-level vocabulary from SMILES strings.

    The special tokens ``<pad>``, ``<unk>``, ``<cls>``, ``<eos>`` always occupy
    ids 0-3. They are followed by the observed characters in sorted order.

    Args:
        smiles_list: Iterable of SMILES strings.
        max_vocab_size: Optional cap on the total vocabulary size, special tokens
            included. ``None`` or ``0`` means no cap. When the cap forces
            characters out, the most frequent characters are kept (ties broken
            alphabetically); dropped characters map to ``<unk>`` at tokenization.
            The special tokens are never dropped, so the effective minimum size
            is 4.

    Returns:
        ``(vocab, char_to_idx, idx_to_char)``: the token list and the two lookup
        dicts between tokens and ids.
    """
    from collections import Counter

    special_tokens = ['<pad>', '<unk>', '<cls>', '<eos>']
    char_counts = Counter(char for smiles in smiles_list for char in smiles)
    chars = sorted(char_counts)

    if max_vocab_size:
        # Reserve room for the special tokens; tokenize_smiles needs <pad> and <unk>.
        budget = max(max_vocab_size - len(special_tokens), 0)
        if len(chars) > budget:
            by_frequency = sorted(char_counts, key=lambda c: (-char_counts[c], c))
            chars = sorted(by_frequency[:budget])

    vocab = special_tokens + chars
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    idx_to_char = {i: char for i, char in enumerate(vocab)}
    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

# Build the character vocabulary from the loaded SMILES; VOCAB_SIZE sizes the
# Transformer's token embedding table (see transformer_config).
vocab, char_to_idx, idx_to_char = build_smiles_vocab(all_smiles)
VOCAB_SIZE = len(vocab)
MAX_SMILES_LEN = 256 # Max sequence length for Transformer; tokenize_smiles truncates/pads to it.

def tokenize_smiles(smiles, char_to_idx, max_len):
    """Map a SMILES string to a fixed-length int32 array of token ids.

    Every character is looked up in ``char_to_idx``; characters missing from the
    vocabulary become the ``<unk>`` id. Sequences longer than ``max_len`` are
    truncated; shorter ones are right-padded with the ``<pad>`` id.
    """
    ids = [char_to_idx.get(ch, char_to_idx['<unk>']) for ch in smiles]
    if len(ids) >= max_len:
        return np.array(ids[:max_len], dtype=np.int32)
    ids += [char_to_idx['<pad>']] * (max_len - len(ids))
    return np.array(ids, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """Return a boolean tensor that is True wherever ``token_ids`` holds the pad id."""
    is_padding = token_ids == pad_token_id
    return tf.cast(is_padding, tf.bool)


# --- SMILES to TensorFlow Graph Conversion ---
# This is the critical part: transforming SMILES into a graph representation usable by a TensorFlow GNN — node-feature arrays plus an edge-index list (a sparse, COO-style adjacency) for efficiency.

# Atom and bond feature definitions.
# The *_FEATURES_LIST values below enumerate the categories a future one-hot
# encoding would need to cover. Within this notebook they are descriptive only:
# the featurizers emit raw values, not one-hot vectors.
ATOM_FEATURES_LIST = [
    6, 7, 8, 9, 15, 16, 17, 35, 53, # Atomic number (C, N, O, F, P, S, Cl, Br, I)
    1, 2, 3, 4, # Degree (number of bonds)
    0, 1, 2, 3, 4, # Hybridization (SP, SP2, SP3, SP3D, SP3D2 as ints)
    0, 1, # Is aromatic
    -1, 0, 1, 2, 3, 4 # Formal charge
]
# Width of the vector returned by atom_to_feature_vector: atomic number, degree,
# hybridization, aromatic flag, formal charge. It must match that function,
# because it sizes the dummy and padded node-feature tensors of the pipeline.
# (len(ATOM_FEATURES_LIST) would be the width of a one-hot encoding, not of
# the current featurizer.)
NUM_ATOM_FEATURES = 5

BOND_FEATURES_LIST = [
    Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC, # Bond types
    0, 1 # Is conjugated
]
# Width of the vector returned by bond_to_feature_vector: bond type, conjugation flag.
NUM_BOND_FEATURES = 2


def atom_to_feature_vector(atom):
    """Encode an RDKit atom as a float32 vector of five raw (not one-hot) values.

    Order: atomic number, degree, hybridization (enum cast to int),
    aromatic flag (0/1), formal charge.
    """
    return np.array(
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            int(atom.GetHybridization()),  # enum -> int
            int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
        ],
        dtype=np.float32,
    )

def bond_to_feature_vector(bond):
    """Encode an RDKit bond as a float32 vector: [bond type as int, conjugated flag (0/1)]."""
    bond_type = int(bond.GetBondType())
    conjugated = int(bond.GetIsConjugated())
    return np.array([bond_type, conjugated], dtype=np.float32)

def smiles_to_tf_graph(smiles_string):
    """Convert a SMILES string into numpy graph arrays for the GIN encoder.

    Returns:
        ``(node_features, edge_indices, num_nodes)``:

        * ``node_features``: float32 array (num_atoms, features per atom) built
          with ``atom_to_feature_vector``.
        * ``edge_indices``: int32 array (num_edges, 2), atom indices local to this
          molecule, with every bond listed in both directions. A single-atom
          molecule gets one ``[0, 0]`` self-loop. A multi-atom molecule without
          bonds (e.g. a salt written as disconnected ions) gets an empty (0, 2)
          array.
        * ``num_nodes``: number of atoms.

        ``(None, None, None)`` when RDKit cannot parse the SMILES or the molecule
        has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return None, None, None

    atom_features = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not atom_features:
        return None, None, None
    node_features = np.array(atom_features, dtype=np.float32)
    num_nodes = len(node_features)

    # GetBonds() yields each bond once; add both directions for an undirected graph.
    edge_list = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_list.extend(([i, j], [j, i]))

    # A lone atom gets a self-loop so it is not an edgeless graph.
    if not edge_list and num_nodes == 1:
        edge_list = [[0, 0]]

    # reshape keeps an empty edge list at shape (0, 2). Every path returns numpy
    # (previously the no-bond case returned a tf tensor).
    edge_indices = np.array(edge_list, dtype=np.int32).reshape(-1, 2)
    return node_features, edge_indices, num_nodes

# --- Creating tf.data.Dataset for TPU ---

In [None]:

# Upper bound on atoms per molecule. Node features are zero-padded to this many
# rows so graphs can be batched. It is also referenced at module level (e.g. by
# padded_shapes in the dataset cell), so it must be a global, not a local.
MAX_NODES = 100


# Featurization function for the dataset map operation; it runs eagerly inside tf.py_function.
def featurize_smiles_and_graph(smiles_string):
    """Featurize one SMILES string (scalar tf.string tensor) for both encoders.

    Returns a 5-tuple ``(node_features, edge_indices, num_nodes, token_ids, mask)``:

    * node features as float32, zero-padded to MAX_NODES rows;
    * edge indices as int32 (num_edges, 2);
    * the real atom count as int32;
    * MAX_SMILES_LEN int32 token ids;
    * a MAX_SMILES_LEN bool mask (True at padded positions).

    Molecules RDKit cannot parse, or with more than MAX_NODES atoms, come back as
    a dummy graph with ``num_nodes == 0``. The dataset's ``num_nodes > 0`` filter
    then drops them.
    """
    smiles = smiles_string.numpy().decode('utf-8')  # decode once, reuse for both encoders
    token_ids = tokenize_smiles(smiles, char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes = smiles_to_tf_graph(smiles)

    # Oversize molecules would make tf.pad below fail on negative padding, and that
    # error would abort the whole input pipeline. Treat them like invalid SMILES.
    if node_features is None or num_nodes > MAX_NODES:
        return (tf.zeros((1, NUM_ATOM_FEATURES), dtype=tf.float32),
                tf.zeros((0, 2), dtype=tf.int32),
                tf.constant(0, dtype=tf.int32),
                tf.convert_to_tensor(token_ids, dtype=tf.int32),
                tf.convert_to_tensor(mask, dtype=tf.bool))

    # Pad node rows to MAX_NODES so every graph has the same node-feature shape.
    padded_node_features = tf.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])

    return (tf.convert_to_tensor(padded_node_features, dtype=tf.float32),
            tf.convert_to_tensor(edge_indices, dtype=tf.int32),
            tf.constant(num_nodes, dtype=tf.int32),
            tf.convert_to_tensor(token_ids, dtype=tf.int32),
            tf.convert_to_tensor(mask, dtype=tf.bool))

# Input pipeline: one element per SMILES string.
dataset = tf.data.Dataset.from_tensor_slices(all_smiles)

# RDKit and the tokenizer are plain Python, so featurization runs eagerly inside
# tf.py_function. Tout lists the dtypes of the 5-tuple returned by
# featurize_smiles_and_graph, in order. The outputs carry no static shapes;
# padded_shapes below supplies them at batching time.
dataset = dataset.map(lambda x: tf.py_function(
    featurize_smiles_and_graph,
    inp=[x],
    Tout=(tf.float32, tf.int32, tf.int32, tf.int32, tf.bool)
), num_parallel_calls=tf.data.AUTOTUNE)

# Drop failed conversions: featurize_smiles_and_graph reports them with num_nodes == 0.
dataset = dataset.filter(lambda node_feat, edge_idx, num_nodes, token_ids, mask: num_nodes > 0)

BATCH_SIZE_PER_REPLICA = 64 
# Batches are built at the global size; they are meant to be split across replicas.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# NOTE: MAX_NODES must exist at module level with the same value used to pad
# node features in featurize_smiles_and_graph.
padded_shapes = (
    tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), # node_features
    tf.TensorShape([None, 2]), # edge_indices (variable length per graph in batch)
    tf.TensorShape([]),        # num_nodes (scalar per graph)
    tf.TensorShape([MAX_SMILES_LEN]), # token_ids
    tf.TensorShape([MAX_SMILES_LEN])  # mask
)
# NOTE(review): edge_indices are padded with 0, so padding rows look like a [0, 0]
# self-loop on each graph's first node. These rows cannot be told apart from the
# genuine self-loop smiles_to_tf_graph adds for single-atom molecules, so the GIN
# side must ignore or tolerate them. Edge indices are also local to each graph
# (0..num_nodes-1), not offset into a batch-wide node axis.
# token_ids/mask padding never actually applies: tokenize_smiles already returns
# exactly MAX_SMILES_LEN entries.
padding_values = (
    tf.constant(0.0, dtype=tf.float32), # node_features padding
    tf.constant(0, dtype=tf.int32),     # edge_indices padding
    tf.constant(0, dtype=tf.int32),     # num_nodes padding (not strictly needed as it's a scalar per element)
    tf.constant(char_to_idx['<pad>'], dtype=tf.int32), # token_ids padding
    tf.constant(True, dtype=tf.bool)    # mask padding
)

# cache() sits after map/filter, so RDKit featurization only runs during the first
# pass. shuffle() comes after the cache so each epoch sees a new order.
dataset = dataset.cache() # We are Caching data after featurization for faster epochs
dataset = dataset.shuffle(buffer_size=10000) 
# drop_remainder=True keeps every batch at exactly GLOBAL_BATCH_SIZE (static batch dimension).
dataset = dataset.padded_batch(GLOBAL_BATCH_SIZE, padded_shapes=padded_shapes, padding_values=padding_values, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

# --- Model Architecture (TensorFlow/Keras) ---

In [None]:
# GIN layer as a custom Keras layer. Neighbour aggregation uses tf.gather +
# tf.math.unsorted_segment_sum instead of a SparseTensor matmul, because
# SparseTensor ops are generally not supported when compiling for TPU (XLA).
class GINLayer(layers.Layer):
    """One Graph Isomorphism Network update: ``h_i' = MLP((1 + eps) * h_i + sum_j h_j)``.

    Call with a tuple ``(node_features, edge_indices, num_nodes)``:

    * ``node_features``: (total_nodes, feature_dim) float tensor — all graphs
      in the batch stacked into one disjoint graph.
    * ``edge_indices``: (num_edges, 2) integer tensor of rows ``[i, j]`` that
      index the stacked node axis. Each row adds ``h_j`` to node ``i``'s
      neighbour sum; a duplicated row is counted once per occurrence.
    * ``num_nodes``: not used by this layer. It is accepted so callers can pass
      the same tuple to every layer.

    Returns a (total_nodes, output_dim) tensor. ``eps`` is a trainable scalar
    initialised to 0. The optional ``activation`` is applied after the MLP.
    """

    def __init__(self, output_dim, activation=None, **kwargs):
        super(GINLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        # Two-layer MLP applied after aggregation (ReLU between the layers).
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(name='epsilon', shape=(),
                                       initializer=keras.initializers.Constant(0.0),
                                       trainable=True)
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        node_features, edge_indices_batch, num_nodes_batch = inputs  # num_nodes_batch is unused

        edge_indices = tf.cast(edge_indices_batch, tf.int32)
        targets = edge_indices[:, 0]  # node receiving the message (row i of A)
        sources = edge_indices[:, 1]  # node whose features are sent (column j of A)

        # Same result as A @ H with A[i, j] = 1 per edge row, built only from
        # dense gather / segment-sum ops.
        neighbor_sum = tf.math.unsorted_segment_sum(
            tf.gather(node_features, sources),
            targets,
            num_segments=tf.shape(node_features)[0],
        )

        # GIN update: MLP((1 + epsilon) * H + Sum(Neighbors))
        combined_features = (1 + self.epsilon) * node_features + neighbor_sum
        output = self.mlp(combined_features)

        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        # (total_nodes, output_dim)
        return input_shape[0][0], self.output_dim

class GINEncoder(keras.Model):
    """Stack of GINLayer + BatchNormalization blocks followed by masked sum pooling.

    Call with ``(node_features, edge_indices, num_nodes)`` as produced by the
    padded tf.data pipeline:

    * ``node_features``: (batch, max_nodes, num_node_features), zero-padded per
      graph. The last dimension must be statically known.
    * ``edge_indices``: (batch, max_edges, 2), indices local to each graph.
    * ``num_nodes``: (batch,) number of real atoms per graph.

    Returns a (batch, hidden_dim) graph embedding: the sum of the final node
    states over each graph's real (unpadded) nodes.

    NOTE(review): BatchNormalization statistics include the padded node rows;
    consider masking them if that skews training.
    """

    def __init__(self, num_layers, hidden_dim, num_node_features=None, **kwargs):
        super(GINEncoder, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        # gin_config passes num_node_features, and keras.Model rejects unknown
        # kwargs, so accept it here. It is informational only: the first Dense
        # layer infers its input width on first call.
        self.num_node_features = num_node_features
        self.gin_layers = []
        for i in range(num_layers):
            # No activation on the last GIN layer, so the pooled embedding is not clipped at zero.
            self.gin_layers.append(GINLayer(hidden_dim, activation='relu' if i < num_layers - 1 else None))
            self.gin_layers.append(layers.BatchNormalization())

    def call(self, inputs, training=None):
        node_features, edge_indices, num_nodes = inputs

        batch_size = tf.shape(node_features)[0]
        max_nodes = tf.shape(node_features)[1]
        total_nodes = batch_size * max_nodes

        # Stack the batch into one disjoint graph: graph b owns rows
        # [b * max_nodes, (b + 1) * max_nodes) of x.
        x = tf.reshape(node_features, [-1, node_features.shape[-1]])
        local_edges = tf.cast(edge_indices, tf.int32)
        offsets = tf.range(batch_size) * max_nodes
        global_edges = local_edges + offsets[:, tf.newaxis, tf.newaxis]

        # Ignore negative (padding) indices and self-loops. With zero padding, a
        # padding row [0, 0] cannot be told apart from the self-loop added for
        # single-atom molecules. GIN's (1 + eps) * h term already carries each
        # node's own features, so dropping self-loops loses nothing essential.
        # Invalid edges are redirected to an extra all-zero "sink" row rather than
        # removed, which keeps the edge tensor's shape static.
        src, dst = local_edges[..., 0], local_edges[..., 1]
        valid = tf.logical_and(tf.logical_and(src >= 0, dst >= 0), tf.not_equal(src, dst))
        sink = total_nodes
        global_edges = tf.where(valid[..., tf.newaxis], global_edges, sink)
        global_edges = tf.reshape(global_edges, [-1, 2])

        for layer in self.gin_layers:
            if isinstance(layer, GINLayer):
                # Append the zero sink row for this layer only, then drop it again.
                with_sink = tf.concat([x, tf.zeros_like(x[:1])], axis=0)
                x = layer((with_sink, global_edges, num_nodes))[:total_nodes]
            else:  # BatchNormalization
                x = layer(x, training=training)

        # Sum pooling over each graph's real nodes (GIN readout).
        x = tf.reshape(x, [batch_size, max_nodes, self.hidden_dim])
        node_mask = tf.sequence_mask(num_nodes, maxlen=max_nodes, dtype=x.dtype)  # (batch, max_nodes)
        return tf.reduce_sum(x * node_mask[..., tf.newaxis], axis=1)

class TransformerEncoder(keras.Model):
    """Post-norm Transformer encoder over SMILES tokens with masked mean pooling.

    Call with ``(token_ids, padding_mask)``: ``token_ids`` is (batch, seq_len)
    int, and ``padding_mask`` is (batch, seq_len) bool with True at padded
    positions. Returns a (batch, embed_dim) layer-normalised sequence embedding.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        # Learned absolute positional embedding, sliced to the input length in call().
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )

        # Each entry: [attention, norm after attention, FF expand, FF project, norm after FF].
        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False, mask=None):
        # `mask` is part of the Keras call signature but unused; padding comes from `inputs`.
        token_ids, padding_mask = inputs
        padding_mask = tf.cast(padding_mask, tf.bool)  # True = padded token

        # Keras MultiHeadAttention expects attention_mask to be True where a query
        # MAY attend to a key, broadcastable to (batch, query_len, key_len).
        # Invert the padding mask and add a query axis so no token attends to padding.
        attention_mask = tf.logical_not(padding_mask)[:, tf.newaxis, :]

        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]

        for attention, norm1, ff_dense1, ff_dense2, norm2 in self.encoder_layers:
            # Attention block (Add & Norm)
            attn_output = attention(x, x, attention_mask=attention_mask, training=training)
            x = norm1(x + attn_output)

            # Feed-forward block (Add & Norm)
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output)

        # Mean pooling over non-padded tokens only.
        non_padded_mask = 1.0 - tf.cast(tf.expand_dims(padding_mask, axis=-1), dtype=x.dtype)  # 1.0 = real token
        sum_embeddings = tf.reduce_sum(x * non_padded_mask, axis=1)
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)

        # The epsilon avoids division by zero for fully padded sequences.
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9)

        return self.final_norm(smiles_embedding)


class ProjectionHead(keras.Model):
    """Two-layer MLP (Dense + ReLU, then Dense) mapping an encoder output into the shared space.

    ``input_dim`` is accepted but unused: the first Dense layer infers its input
    width when first called.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, **kwargs):
        super().__init__(**kwargs)
        hidden_layer = layers.Dense(hidden_dim, activation='relu')
        output_layer = layers.Dense(output_dim)
        self.net = keras.Sequential([hidden_layer, output_layer])

    def call(self, x):
        projected = self.net(x)
        return projected

class GRASPModel(keras.Model):
    """Two-tower contrastive model pairing a GIN graph encoder with a SMILES Transformer.

    Each tower's pooled embedding goes through its own ProjectionHead into a
    shared ``projection_dim`` space and is then L2-normalised.
    """

    def __init__(self, gin_config, transformer_config, projection_dim, **kwargs):
        super(GRASPModel, self).__init__(**kwargs)
        self.gin_encoder = GINEncoder(**gin_config)
        self.transformer_encoder = TransformerEncoder(**transformer_config)

        self.graph_projection_head = ProjectionHead(gin_config['hidden_dim'], projection_dim)
        self.smiles_projection_head = ProjectionHead(transformer_config['embed_dim'], projection_dim)

    def call(self, inputs, training=False):
        """Return ``(graph_embeddings, smiles_embeddings)``, each (batch, projection_dim) and unit-norm.

        ``inputs`` is the 5-tuple yielded by the dataset:
        ``(node_features, edge_indices, num_nodes, token_ids, smiles_mask)``.
        """
        node_features, edge_indices, num_nodes, token_ids, smiles_mask = inputs

        # Encoding graphs
        graph_embeddings_raw = self.gin_encoder((node_features, edge_indices, num_nodes), training=training)
        graph_embeddings_projected = self.graph_projection_head(graph_embeddings_raw, training=training)

        # Encoding SMILES
        smiles_embeddings_raw = self.transformer_encoder((token_ids, smiles_mask), training=training)
        smiles_embeddings_projected = self.smiles_projection_head(smiles_embeddings_raw, training=training)

        # L2-normalise for the InfoNCE loss, so dot products are cosine similarities.
        # l2_normalize clamps the norm with a small epsilon; tf.linalg.normalize
        # would produce NaN for an all-zero embedding.
        graph_embeddings_projected = tf.math.l2_normalize(graph_embeddings_projected, axis=1)
        smiles_embeddings_projected = tf.math.l2_normalize(smiles_embeddings_projected, axis=1)

        return graph_embeddings_projected, smiles_embeddings_projected

# --- Contrastive Loss (InfoNCE) ---

In [None]:
class InfoNCELoss(keras.losses.Loss):
    """Symmetric InfoNCE contrastive loss between paired embedding batches.

    Row ``i`` of ``graph_embeddings`` and row ``i`` of ``smiles_embeddings``
    form a positive pair; every other row of the batch acts as a negative. The
    loss averages the graph->SMILES and SMILES->graph cross-entropies and
    returns the scalar mean over the batch. Inputs are expected to be
    L2-normalised so the dot products are cosine similarities.

    Called as ``loss(graph_embeddings, smiles_embeddings)``; Keras passes them
    through as ``y_true`` / ``y_pred``.
    """

    def __init__(self, temperature=0.07, name='info_nce_loss', **kwargs):
        # call() already reduces to a scalar, so no further Keras reduction is
        # needed. The default AUTO / sum-over-batch-size reduction also raises
        # when the loss is called inside tf.distribute.Strategy.run in a custom
        # training loop. A caller-supplied reduction still takes precedence.
        kwargs.setdefault('reduction', 'none')
        super().__init__(name=name, **kwargs)
        self.temperature = temperature

    def call(self, graph_embeddings, smiles_embeddings):
        # logits[i, j] = sim(g_i, s_j) / temperature
        logits = tf.matmul(graph_embeddings, smiles_embeddings, transpose_b=True) / self.temperature

        # The positive for row i is column i (the diagonal).
        batch_size = tf.shape(logits)[0]
        labels = tf.range(batch_size)

        # graph -> SMILES: how well each graph picks out its own SMILES.
        loss_g_s = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        # SMILES -> graph: how well each SMILES picks out its own graph.
        loss_s_g = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=tf.transpose(logits))

        # Average both directions, then the batch.
        return tf.reduce_mean((loss_g_s + loss_s_g) / 2)

# --- Training Loop ---

In [None]:
PROJECTION_DIM = 128 # Dimension of the shared embedding space
HIDDEN_DIM_GIN = 256 
NUM_GIN_LAYERS = 3   
EMBED_DIM_TRANSFORMER = 256 
NUM_TRANSFORMER_HEADS = 8
NUM_TRANSFORMER_LAYERS = 3

# Keyword arguments forwarded verbatim as GINEncoder(**gin_config); every key
# must be accepted by GINEncoder.__init__ (keras.Model rejects unknown kwargs).
gin_config = {
    'num_layers': NUM_GIN_LAYERS,
    'hidden_dim': HIDDEN_DIM_GIN,
    # Atom features from featurization. This is the input_dim for the first GIN layer.
    # It must match the output of `atom_to_feature_vector`.
    'num_node_features': len(atom_to_feature_vector(Chem.Atom(6))) # Using a dummy atom to get feature count
}

# Keyword arguments forwarded verbatim as TransformerEncoder(**transformer_config).
# max_seq_len sizes the learned positional embedding and must be >= MAX_SMILES_LEN.
transformer_config = {
    'vocab_size': VOCAB_SIZE,
    'embed_dim': EMBED_DIM_TRANSFORMER,
    'num_heads': NUM_TRANSFORMER_HEADS,
    'num_layers': NUM_TRANSFORMER_LAYERS,
    'max_seq_len': MAX_SMILES_LEN
}

# Training step, executed once per replica via strategy.run.
@tf.function
def train_step(inputs):
    """Run one optimisation step on a (per-replica) batch and return its scaled loss.

    The loss is divided by ``strategy.num_replicas_in_sync`` before
    differentiation. Gradients are summed across replicas when applied inside a
    distribution strategy, so this makes each update the gradient of the mean
    loss. The returned value is the scaled loss, so a ``ReduceOp.SUM`` over
    replicas yields the mean per-replica loss. With a single replica nothing
    changes.
    """
    with tf.GradientTape() as tape:
        graph_embeddings, smiles_embeddings = model(inputs, training=True)
        loss = info_nce_loss(graph_embeddings, smiles_embeddings)
        scaled_loss = loss / strategy.num_replicas_in_sync

    gradients = tape.gradient(scaled_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return scaled_loss

# Instantiate the model, loss, and optimizer inside the strategy scope so their
# variables are created under the active distribution strategy (TPU or default).
# NOTE(review): model weights are created lazily on the first forward pass
# (inside train_step). Consider building the model in this scope first if
# variable creation inside strategy.run causes errors.
with strategy.scope():
    model = GRASPModel(gin_config, transformer_config, PROJECTION_DIM)
    info_nce_loss = InfoNCELoss(temperature=0.07) # Adjust temperature as needed
    optimizer = keras.optimizers.Adam(learning_rate=1e-4)

    # The custom loop below does not go through compile()/fit();
    # train_step calls info_nce_loss and optimizer directly.
    model.compile(optimizer=optimizer, loss=info_nce_loss)

EPOCHS = 5

# Split each global batch across replicas. With the default single-device
# strategy this is effectively a pass-through.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# dataset.filter() makes the cardinality unknown, in which case cardinality()
# returns a negative sentinel rather than a step count.
steps_per_epoch = tf.data.experimental.cardinality(dataset).numpy()
steps_label = str(steps_per_epoch) if steps_per_epoch >= 0 else 'unknown'

print(f"\nStarting pre-training for {EPOCHS} epochs...")
print(f"Steps per epoch: {steps_label}")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    total_loss = 0.0
    num_batches = 0

    for batch_inputs in dist_dataset:
        per_replica_losses = strategy.run(train_step, args=(batch_inputs,))
        # Summing per-replica losses equals the mean loss only when each replica's
        # loss was already divided by strategy.num_replicas_in_sync.
        batch_loss = float(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None))
        total_loss += batch_loss
        num_batches += 1

        if num_batches % 10 == 0:
            print(f"  Batch {num_batches}/{steps_label}, Loss: {batch_loss:.4f}", end='\r')

    if num_batches == 0:
        # With drop_remainder=True this happens when fewer valid molecules than
        # GLOBAL_BATCH_SIZE survive featurization.
        print(f"Epoch {epoch + 1}: no batches were produced; check dataset size vs. GLOBAL_BATCH_SIZE.")
        continue

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}")

print("\nPre-training complete!")

# --- Saving the model ---

In [None]:
# Alternative: save only the two encoders (e.g. for downstream fine-tuning):
# tf.saved_model.save(model.gin_encoder, 'gin_encoder_pretrained')
# tf.saved_model.save(model.transformer_encoder, 'transformer_encoder_pretrained')
# print("Encoders saved.")


# Save the full two-tower model.
# NOTE(review): under Keras 2 a path without an extension is written as a
# TensorFlow SavedModel directory. Keras 3 (TF >= 2.16) instead requires a
# `.keras` (or `.h5`) filename here — confirm which Keras version is installed.
model.save('grasp_pretrained_model')
print("Model saved to 'grasp_pretrained_model' directory.")