In [1]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2025.3.3-cp310-cp310-manylinux_2_28_x86_64.whl (34.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.9/34.9 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
Installing collected packages: rdkit
Successfully installed rdkit-2025.3.3
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# Importing Libraries

In [3]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem

E0000 00:00:1750622494.882225      10 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:230


# init TPU (as we are using the Kaggle TPU 4 for faster processing)

In [4]:

try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
    print('Not running on TPU, defaulting to GPU/CPU.')
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # Default to GPU or CPU strategy

print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

Not running on TPU, defaulting to GPU/CPU.
Number of accelerators: 1


# Data Loading and preprocesing

In [5]:
SMILES_FILE_PATH = '/kaggle/input/pubchem-smiles-for-pretraining/pubchem_smiles_for_pretraining.txt'

max_nodes_found = 0
num_lines_to_check = 100000 

with open(SMILES_FILE_PATH, 'r') as f:
    for i, line in enumerate(f):
        if i >= num_lines_to_check:
            break
        smiles = line.strip()
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            num_nodes = mol.GetNumAtoms()
            if num_nodes > max_nodes_found:
                max_nodes_found = num_nodes
print(f"Maximum nodes found in the first {num_lines_to_check} molecules: {max_nodes_found}")

def load_smiles_data(file_path, num_samples=None):
    smiles_list = []
    with open(file_path, 'r') as f:
        for i, line in enumerate(f):
            if num_samples and i >= num_samples:
                break
            smiles_list.append(line.strip())
    print(f"Loaded {len(smiles_list)} SMILES strings.")
    return smiles_list


# For testing we will use 50,000; 100,000 for quick tests basically.
all_smiles = load_smiles_data(SMILES_FILE_PATH, num_samples=100000) 

# We are Defining the featurization function for the dataset map operation
def featurize_smiles_and_graph(smiles_string):
    token_ids = tokenize_smiles(smiles_string.numpy().decode('utf-8'), char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles_string.numpy().decode('utf-8'))

    # Handling the cases where graph conversion fails (e.g., invalid smiles)
    if node_features is None:
        dummy_edge_indices = tf.zeros((0, 2), dtype=tf.int32)
        dummy_num_nodes = tf.constant(0, dtype=tf.int32)
        dummy_num_edges = tf.constant(0, dtype=tf.int32) 
        return dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, token_ids, mask 
    
    # We will ensure node_features has consistent shape by padding if necessary for batching
    # It assumes a maximum number of nodes in any graph.
    padded_node_features = tf.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    
    return (tf.constant(padded_node_features, dtype=tf.float32),
            tf.constant(edge_indices, dtype=tf.int32),
            tf.constant(num_nodes, dtype=tf.int32),
            tf.constant(num_edges, dtype=tf.int32),
            tf.constant(token_ids, dtype=tf.int32),
            tf.constant(mask, dtype=tf.bool))

dataset = tf.data.Dataset.from_tensor_slices(all_smiles)

# Mapping the featurization function by using tf.py_function for non-TF operations (RDKit)
dataset = dataset.map(lambda x: tf.py_function(
    featurize_smiles_and_graph,
    inp=[x],
    Tout=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.bool) 
), num_parallel_calls=tf.data.AUTOTUNE)

# Filtering out the failed conversions
dataset = dataset.filter(lambda node_feat, edge_idx, num_nodes, num_edges, token_ids, mask: num_nodes > 0)

BATCH_SIZE_PER_REPLICA = 64 
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync



Maximum nodes found in the first 100000 molecules: 419
Loaded 100000 SMILES strings.


I0000 00:00:1750622528.830574      10 service.cc:148] XLA service 0x592af79b8810 initialized for platform TPU (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1750622528.830615      10 service.cc:156]   StreamExecutor device (0): TPU, 2a886c8
I0000 00:00:1750622528.830619      10 service.cc:156]   StreamExecutor device (1): TPU, 2a886c8
I0000 00:00:1750622528.830622      10 service.cc:156]   StreamExecutor device (2): TPU, 2a886c8
I0000 00:00:1750622528.830625      10 service.cc:156]   StreamExecutor device (3): TPU, 2a886c8
I0000 00:00:1750622528.830627      10 service.cc:156]   StreamExecutor device (4): TPU, 2a886c8
I0000 00:00:1750622528.830629      10 service.cc:156]   StreamExecutor device (5): TPU, 2a886c8
I0000 00:00:1750622528.830632      10 service.cc:156]   StreamExecutor device (6): TPU, 2a886c8
I0000 00:00:1750622528.830637      10 service.cc:156]   StreamExecutor device (7): TPU, 2a886c8


# --- Creating tf.data.Dataset for TPU ---

In [6]:
# --- SMILES Tokenization ---
# A simple character-level tokenizer for demonstration.
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    all_chars = set()
    for smiles in smiles_list:
        for char in smiles:
            all_chars.add(char)
    vocab = sorted(list(all_chars))
    # Adding special tokens
    vocab = ['<pad>', '<unk>', '<cls>', '<eos>'] + vocab
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    idx_to_char = {i: char for i, char in enumerate(vocab)}
    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

vocab, char_to_idx, idx_to_char = build_smiles_vocab(all_smiles)
VOCAB_SIZE = len(vocab)
MAX_SMILES_LEN = 256 # Max sequence length for Transformer.
MAX_NODES = max_nodes_found # Maximum number of nodes in any graph in a batch.

def tokenize_smiles(smiles, char_to_idx, max_len):
    """Here we basically plan to convert a SMILES string to a sequence of token IDs."""
    tokens = list(smiles)
    indexed_tokens = [char_to_idx.get(char, char_to_idx['<unk>']) for char in tokens]
    
    # Pad or truncate
    if len(indexed_tokens) < max_len:
        padded_tokens = indexed_tokens + [char_to_idx['<pad>']] * (max_len - len(indexed_tokens))
    else:
        padded_tokens = indexed_tokens[:max_len]
    return np.array(padded_tokens, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """This will create a boolean mask for padded tokens."""
    return tf.cast(token_ids == pad_token_id, tf.bool)


# --- SMILES to TensorFlow Graph Conversion ---
# This is the critical part, transforming SMILES to a graph representation which will be usable by a TensorFlow GNN, specifically sparse tensors for efficiency.

# Defining the atom and bond features
# These are just example indices for features, along the way we will design specific one-hot encodings or numerical features based on chemical intuition.
ATOM_FEATURES_LIST = [
    6, 7, 8, 9, 15, 16, 17, 35, 53, # Atomic number (C, N, O, F, P, S, Cl, Br, I)
    1, 2, 3, 4, # Degree (number of bonds)
    0, 1, 2, 3, 4, # Hybridization (SP, SP2, SP3, SP3D, SP3D2 as ints)
    0, 1, # Is aromatic
    -1, 0, 1, 2, 3, 4 # Formal charge
]
NUM_ATOM_FEATURES = len(ATOM_FEATURES_LIST) # we are gonna make it the output dim of our featurizer

BOND_FEATURES_LIST = [
    Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC, # Bond types
    0, 1 # Is conjugated
]
NUM_BOND_FEATURES = len(BOND_FEATURES_LIST) # we will also make this the output dim of our featurizer

padded_shapes = (
    tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), # node_features
    tf.TensorShape([None, 2]), # edge_indices (variable length per graph in batch)
    tf.TensorShape([]),        # num_nodes (scalar per graph)
    tf.TensorShape([]),        # num_edges (scalar per graph) 
    tf.TensorShape([MAX_SMILES_LEN]), # token_ids
    tf.TensorShape([MAX_SMILES_LEN])  # mask
)
padding_values = (
    tf.constant(0.0, dtype=tf.float32), # node_features padding
    tf.constant(0, dtype=tf.int32),     # edge_indices padding
    tf.constant(0, dtype=tf.int32),     # num_nodes padding
    tf.constant(0, dtype=tf.int32),     # num_edges padding 
    tf.constant(char_to_idx['<pad>'], dtype=tf.int32), # token_ids padding
    tf.constant(True, dtype=tf.bool)    # mask padding
)

dataset = dataset.cache() # We are Caching data after featurization for faster epochs
dataset = dataset.shuffle(buffer_size=10000) 
dataset = dataset.padded_batch(GLOBAL_BATCH_SIZE, padded_shapes=padded_shapes, padding_values=padding_values, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

def atom_to_feature_vector(atom):
    """This will convert an RDKit atom to a feature vector."""
    features = []
    # Atomic number (one-hot or direct embedding if using learned features)
    features.append(atom.GetAtomicNum())
    features.append(atom.GetDegree())
    features.append(int(atom.GetHybridization())) # Convert enum to int
    features.append(int(atom.GetIsAromatic()))
    features.append(atom.GetFormalCharge())
    return np.array(features, dtype=np.float32)

def bond_to_feature_vector(bond):
    """converting an RDKit bond to a feature vector."""
    features = []
    features.append(int(bond.GetBondType()))
    features.append(int(bond.GetIsConjugated()))
    return np.array(features, dtype=np.float32)

def smiles_to_tf_graph(smiles_string):
    """
    Converts a SMILES string to TensorFlow graph components:
    node_features, edge_indices, num_nodes, and num_edges.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        # Return Nones for graph data when RDKit fails to parse SMILES
        # Ensuring all 4 return values are present
        return None, None, None, None 

    # Getting node features
    node_features = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not node_features: # Handle empty molecules after atom featurization (should be rare)
        # Ensuring all 4 return values are present
        return None, None, None, None 
    node_features = np.array(node_features, dtype=np.float32)
    num_nodes = len(node_features)

    # Get edge indices (adjacency list format)
    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])  # Adding the reverse edge for undirected graph

    # --- CRITICAL FIXES START HERE ---
    # Use explicit length check for `edge_indices` (which is a Python list at this point)
    if len(edge_indices) == 0: 
        # Handle single atom molecules or molecules with no bonds
        if num_nodes == 1:
            edge_indices_final = np.array([[0, 0]], dtype=np.int32) # Add a self-loop as a NumPy array
            num_edges_final = 1 # Update num_edges for the self-loop
        else:
            # For molecules with >1 node but no bonds (e.g., [C].[C]),
            # or where parsing failed to produce bonds.
            edge_indices_final = tf.zeros((0, 2), dtype=tf.int32) # Return empty TF tensor
            num_edges_final = 0 # No edges
        
        # All return paths now explicitly return all 4 values
        return node_features, edge_indices_final, num_nodes, num_edges_final
    
    # If `edge_indices` was not empty initially, convert it to numpy array
    edge_indices_final = np.array(edge_indices, dtype=np.int32)
    
    # The final `num_edges` should be the length of the *array* of edges being returned.
    num_edges_final = len(edge_indices_final) # Robustly get the count from the final edge array

    return node_features, edge_indices_final, num_nodes, num_edges_final

Built vocabulary of size: 71


# --- Model Architecture (TensorFlow/Keras) ---

In [19]:
# GIN Layer (as a Custom Keras Layer) we are doing it as this implements the core GIN aggregation as a Keras layer suitable for sparse tensors and TPU
class GINLayer(layers.Layer):
    def __init__(self, output_dim, activation=None, **kwargs):
        super(GINLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(name='epsilon', shape=(),
                                       initializer=keras.initializers.Constant(0.0),
                                       trainable=True)
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        node_features, edge_indices_batch, num_nodes_batch = inputs
        
        edge_values = tf.ones(tf.shape(edge_indices_batch)[0], dtype=tf.float32)
        
        total_nodes_in_batch = tf.shape(node_features)[0]
        adj_shape = tf.cast([total_nodes_in_batch, total_nodes_in_batch], dtype=tf.int64)

        adj_sparse = tf.sparse.SparseTensor(indices=tf.cast(edge_indices_batch, tf.int64),
                                            values=edge_values,
                                            dense_shape=adj_shape)
        
        # Sum of neighbor features: (A * H)
        neighbor_sum = tf.sparse.sparse_dense_matmul(adj_sparse, node_features)

        # GIN update: MLP((1 + epsilon) * H + Sum(Neighbors))
        combined_features = (1 + self.epsilon) * node_features + neighbor_sum
        output = self.mlp(combined_features)

        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        return input_shape[0][0], self.output_dim # (batch_size * max_nodes_per_graph, output_dim)

class GINEncoder(keras.Model):
    def __init__(self, num_layers, hidden_dim, num_node_features, **kwargs): 
        super(GINEncoder, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim 
        
        self.initial_mlp = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(hidden_dim) 
        ])
        
        self.gin_layers = []
        self.bns = [] # Initialize the list for Batch Normalization layers
        
        # Loop to create GIN layers AND corresponding Batch Normalization layers
        # Each GINLayer will be followed by a BatchNorm.
        for i in range(num_layers): 
            self.gin_layers.append(GINLayer(hidden_dim, activation='relu' if i < num_layers - 1 else None))
            self.bns.append(layers.BatchNormalization())
    
    def call(self, inputs):
        node_features, edge_indices, num_nodes = inputs

        x = self.initial_mlp(node_features) # Applying initial transformation
        
        # Correctly iterate through corresponding GIN and BatchNorm layers
        for i in range(len(self.gin_layers)): 
            x = self.gin_layers[i]((x, edge_indices, num_nodes)) # Pass inputs to GINLayer
            x = self.bns[i](x) # Apply BatchNorm AFTER the GINLayer output
        
        batch_size = tf.shape(node_features)[0] // MAX_NODES
        
        # Reshaping to (batch_size, MAX_NODES, hidden_dim)
        # Use the hidden_dim as the last dimension, it's consistent for GIN layers.
        x_reshaped = tf.reshape(x, (batch_size, MAX_NODES, self.hidden_dim)) 
        
        # Creating a mask for valid nodes in each graph
        sequence_mask = tf.sequence_mask(num_nodes, maxlen=MAX_NODES, dtype=tf.float32) 
        sequence_mask = tf.expand_dims(sequence_mask, axis=-1) 
        
        # Applying mask and sum
        masked_x = x_reshaped * sequence_mask
        graph_embedding = tf.reduce_sum(masked_x, axis=1) # Sum pooling
        
        return graph_embedding

class TransformerEncoder(keras.Model):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )

        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False, mask=None): # 'mask' parameter is actually `padding_mask` from GRASPModel
        token_ids, padding_mask_bool = inputs # Renamed to padding_mask_bool for clarity
        
        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]

        # Create an additive attention mask from the boolean padding_mask
        # MultiHeadAttention expects a mask that is added to the attention scores (logits).
        
        attention_mask_additive = tf.cast(padding_mask_bool, dtype=tf.float32) * -1e9
        
        # Simplest form of additive mask for Keras MHA: (batch_size, 1, key_sequence_length)
        # This will be broadcast across query dimension and heads.
        attention_mask_additive = tf.expand_dims(attention_mask_additive, axis=1) # (batch_size, 1, MAX_SMILES_LEN)
        
        
        for i, (attention, norm1, ff_dense1, ff_dense2, norm2) in enumerate(self.encoder_layers):
            # Attention block
            # Pass the prepared additive mask
            attn_output = attention(x, x, attention_mask=attention_mask_additive, training=training)
            x = norm1(x + attn_output) # Add & Norm

            # Feed-forward block
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output) # Add & Norm
        
        # Global pooling: mask out padded tokens before mean pooling

        expanded_padding_mask = tf.cast(tf.expand_dims(padding_mask_bool, axis=-1), dtype=x.dtype)
        
        # Invert the mask logic: 1.0 for non-padded, 0.0 for padded
        non_padded_mask = 1.0 - expanded_padding_mask
        
        # Apply mask: so we will set padded embeddings to 0
        x_masked = x * non_padded_mask
        
        # Sum along sequence dimension
        sum_embeddings = tf.reduce_sum(x_masked, axis=1)
        
        # Counting non-padded elements per sequence
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)
        
        # Mean pooling (also avoiding division by zero for fully padded sequences)
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9) # Adding epsilon to avoid div by zero
        
        return self.final_norm(smiles_embedding)


class ProjectionHead(keras.Model):
    def __init__(self, input_dim, output_dim, hidden_dim=256, **kwargs):
        super(ProjectionHead, self).__init__(**kwargs)
        self.net = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(output_dim)
        ])

    def call(self, x):
        return self.net(x)

class GRASPModel(keras.Model):
    def __init__(self, gin_config, transformer_config, projection_dim, **kwargs):
        super(GRASPModel, self).__init__(**kwargs)
        self.gin_encoder = GINEncoder(**gin_config)
        self.transformer_encoder = TransformerEncoder(**transformer_config)
        
        self.graph_projection_head = ProjectionHead(gin_config['hidden_dim'], projection_dim)
        self.smiles_projection_head = ProjectionHead(transformer_config['embed_dim'], projection_dim)
    
    def call(self, inputs, training=False):
        node_features_padded, edge_indices_padded, num_nodes, num_edges, token_ids, smiles_mask = inputs
        
        node_features_flat = tf.reshape(node_features_padded, (-1, tf.shape(node_features_padded)[2]))
        
        batch_size = tf.shape(node_features_padded)[0] 
        
        # --- Handling Edge Indices for GINLayer ---
        # Creating a mask to identify valid edges (non-padded ones)
        edge_mask = tf.sequence_mask(num_edges, maxlen=tf.shape(edge_indices_padded)[1], dtype=tf.bool)
         
        # Filtering out padded edges, This will bascially flatten the valid edges across the entire batch
        valid_edge_indices = tf.boolean_mask(edge_indices_padded, edge_mask)
        
        # Creating global node offsets for each graph in the batch
        # This transforms local node IDs (0 to MAX_NODES-1) into global IDs across the flattened node list
        node_offsets_for_edges = tf.range(batch_size) * MAX_NODES 
        # Expanding and tiling this offset to apply to each edge
        node_offsets_for_edges = tf.expand_dims(node_offsets_for_edges, axis=1) 
        node_offsets_for_edges_expanded = tf.boolean_mask(tf.tile(node_offsets_for_edges, [1, tf.shape(edge_indices_padded)[1]]), edge_mask)
        node_offsets_for_edges_expanded = tf.expand_dims(node_offsets_for_edges_expanded, axis=-1) 
        
        # Applying offsets to get global edge indices
        global_edge_indices_filtered = valid_edge_indices + tf.cast(node_offsets_for_edges_expanded, dtype=tf.int32)
        
        # Encoding graphs
        graph_embeddings_raw = self.gin_encoder((node_features_flat, global_edge_indices_filtered, num_nodes), training=training)
        graph_embeddings_projected = self.graph_projection_head(graph_embeddings_raw, training=training)
        
        # Encode SMILES
        smiles_embeddings_raw = self.transformer_encoder((token_ids, smiles_mask), training=training)
        smiles_embeddings_projected = self.smiles_projection_head(smiles_embeddings_raw, training=training)
        
        # Apply L2 normalization to projected embeddings for InfoNCE loss
        graph_embeddings_projected = tf.linalg.normalize(graph_embeddings_projected, axis=1)[0]
        smiles_embeddings_projected = tf.linalg.normalize(smiles_embeddings_projected, axis=1)[0]
        
        return graph_embeddings_projected, smiles_embeddings_projected

# --- Contrastive Loss (InfoNCE) ---

In [20]:
class InfoNCELoss(keras.losses.Loss):
    def __init__(self, temperature=0.07, name='info_nce_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.temperature = temperature

    @tf.function
    def call(self, graph_embeddings, smiles_embeddings):

        # Cosine similarity matrix (logits)
        # S_ij = sim(g_i, s_j)
        logits = tf.matmul(graph_embeddings, smiles_embeddings, transpose_b=True) / self.temperature
        
        # Creating labels: diagonal elements are positive pairs
        batch_size = tf.shape(logits)[0]
        labels = tf.eye(batch_size) # (batch_size, batch_size)

        # Calculating cross-entropy for graph->SMILES and SMILES->graph
        # loss_g_s: how well graph embeddings predict their corresponding SMILES
        loss_g_s = tf.keras.losses.categorical_crossentropy(labels, logits, from_logits=True)
        
        # loss_s_g: how well SMILES embeddings predict their corresponding graphs
        loss_s_g = tf.keras.losses.categorical_crossentropy(labels, tf.transpose(logits), from_logits=True)

        # Total loss is the average of both directions
        total_loss = (loss_g_s + loss_s_g) / 2
        
        # Reduce mean over the batch
        return tf.reduce_mean(total_loss)

# --- Training Loop ---

In [None]:
PROJECTION_DIM = 128 # Dimension of the shared embedding space
HIDDEN_DIM_GIN = 256 
NUM_GIN_LAYERS = 3   
EMBED_DIM_TRANSFORMER = 256 
NUM_TRANSFORMER_HEADS = 8
NUM_TRANSFORMER_LAYERS = 3

gin_config = {
    'num_layers': NUM_GIN_LAYERS,
    'hidden_dim': HIDDEN_DIM_GIN,
    # Atom features from featurization. This is the input_dim for the first GIN layer.
    # It must match the output of `atom_to_feature_vector`.
    # 'num_node_features': len(atom_to_feature_vector(Chem.Atom(6))) # Using a dummy atom to get feature count
    'num_node_features': 5 
}

transformer_config = {
    'vocab_size': VOCAB_SIZE,
    'embed_dim': EMBED_DIM_TRANSFORMER,
    'num_heads': NUM_TRANSFORMER_HEADS,
    'num_layers': NUM_TRANSFORMER_LAYERS,
    'max_seq_len': MAX_SMILES_LEN
}

# Defining the training step function
@tf.function
def train_step(inputs):
    with tf.GradientTape() as tape:
        graph_embeddings, smiles_embeddings = model(inputs, training=True)
        loss = info_nce_loss(graph_embeddings, smiles_embeddings)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# Instantiating model, loss, and optimizer within the TPU strategy scope
with strategy.scope():
    model = GRASPModel(gin_config, transformer_config, PROJECTION_DIM)
    info_nce_loss = InfoNCELoss(temperature=0.07) # Adjust temperature as needed
    optimizer = keras.optimizers.Adam(learning_rate=1e-4)

    model.compile(optimizer=optimizer, loss=info_nce_loss)

EPOCHS = 5 
steps_per_epoch = tf.data.experimental.cardinality(dataset).numpy()

print(f"\nStarting pre-training for {EPOCHS} epochs...")
print(f"Steps per epoch: {steps_per_epoch}")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    total_loss = 0.0
    num_batches = 0

    for batch_inputs in dataset:
        per_replica_losses = strategy.run(train_step, args=(batch_inputs,))
        batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        total_loss += batch_loss
        num_batches += 1
        
        if num_batches % 10 == 0:
            print(f"  Batch {num_batches}/{steps_per_epoch}, Loss: {batch_loss:.4f}", end='\r')
    
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}")

print("\nPre-training complete!")


Starting pre-training for 5 epochs...
Steps per epoch: -2

Epoch 1/5




  Batch 40/-2, Loss: 2.3677

# --- Saving the model ---

In [None]:
# To save just the encoders:
# tf.saved_model.save(model.gin_encoder, 'gin_encoder_pretrained')
# tf.saved_model.save(model.transformer_encoder, 'transformer_encoder_pretrained')
# print("Encoders saved.")


model.save('grasp_pretrained_model')
print("Model saved to 'grasp_pretrained_model' directory.")