In [1]:
# %%
!pip install rdkit --quiet
!pip install deepchem --quiet
!pip install tqdm --quiet

In [2]:
# %%
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem
from tqdm import tqdm # For progress bars

# Suppress non-critical RDKit warnings
from rdkit import rdBase
rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

2025-07-01 18:56:31.056985: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751396191.081378     260 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751396191.088792     260 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Select the distribution strategy used for model creation and training:
# the first visible GPU when TensorFlow reports one, otherwise TensorFlow's
# default strategy. The check is backend-agnostic, so the
# messages do not name a specific GPU vendor.
print("Checking for GPU support...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"Found {len(gpus)} GPU(s) available.")
        strategy = tf.distribute.OneDeviceStrategy(device="/GPU:0")
        print(f"Using single-GPU strategy on: {gpus[0].name}")
    else:
        print("No GPU device found, defaulting to CPU.")
        strategy = tf.distribute.get_strategy()  # Default strategy
except Exception as e:
    # Best effort: any failure while probing devices falls back to the default strategy.
    print(f"Error during GPU detection/setup: {e}")
    print("Defaulting to CPU.")
    strategy = tf.distribute.get_strategy()

print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

Checking for Metal GPU support on M1/Apple Silicon...
Found 1 GPU(s) available.
Using Metal GPU strategy on: /physical_device:GPU:0
Number of accelerators: 1


I0000 00:00:1751396195.287462     260 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


In [4]:
# --- IMPORTANT: Update PUBCHEM_TFRECORDS_DIR to your local path ---
# Directory that is globbed for `*.tfrecord` shards when the dataset is loaded.
PUBCHEM_TFRECORDS_DIR = '/kaggle/input/pubchem-tfrecords-1m/pubchem_tfrecords_1M' # Directory created by preprocess_pubchem.py

# --- Data Parameters (MUST match TFRecord creation) ---
# These fix the shapes the TFRecord parser reshapes the raw bytes into, so a
# mismatch with the preprocessing script makes the reshape fail.
MAX_SMILES_LEN = 256 
MAX_NODES = 419 # This MUST be the MAX_NODES used when creating the TFRecords
# NOTE: this value is reassigned later in the notebook from the length of the
# vector returned by atom_to_feature_vector.
NUM_ATOM_FEATURES = 5 # This MUST be the NUM_ATOM_FEATURES used when creating the TFRecords

# --- Model Hyperparameters ---
PROJECTION_DIM = 128  # Output size of both projection heads
HIDDEN_DIM_GIN = 256  # Width of the GIN encoder
NUM_GIN_LAYERS = 3   
EMBED_DIM_TRANSFORMER = 256  # Token embedding width of the SMILES transformer
NUM_TRANSFORMER_HEADS = 8
NUM_TRANSFORMER_LAYERS = 3

# --- Training Parameters ---
BATCH_SIZE_PER_REPLICA = 64 # Batch size per replica. Adjust based on available accelerator memory.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync # For 1 device, this is BATCH_SIZE_PER_REPLICA
LEARNING_RATE = 1e-4
TEMPERATURE = 0.07 # Temperature for InfoNCE Loss
EPOCHS = 5 # Number of pre-training epochs.

# --- Checkpointing and Resuming ---
CHECKPOINT_DIR = 'pretraining_checkpoints' 
RESUME_TRAINING = False # Set to True to resume from the best saved model
# First value of the (0-based) epoch counter used by the training loop.
START_EPOCH = 0 # If resuming, set this to the epoch you want to start from (e.g., 3 to start epoch 3)

In [5]:
# --- SMILES Tokenization ---
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    """Build a character-level vocabulary from an iterable of SMILES strings.

    The vocabulary starts with the special tokens ``<pad>``, ``<unk>``,
    ``<cls>`` and ``<eos>`` (indices 0-3), followed by every distinct character
    seen in ``smiles_list`` in sorted order.

    Args:
        smiles_list: Iterable of strings; each string is split into characters.
        max_vocab_size: If truthy, the vocabulary (special tokens included) is
            truncated to this many entries.

    Returns:
        A tuple ``(vocab, char_to_idx, idx_to_char)``: the token list and the
        two lookup dictionaries derived from it.
    """
    special_tokens = ['<pad>', '<unk>', '<cls>', '<eos>']
    observed_chars = {symbol for smiles in smiles_list for symbol in smiles}
    vocab = special_tokens + sorted(observed_chars)
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]

    char_to_idx = {}
    idx_to_char = {}
    for index, token in enumerate(vocab):
        char_to_idx[token] = index
        idx_to_char[index] = token

    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

def tokenize_smiles(smiles, char_to_idx, max_len):
    """Convert a SMILES string into a fixed-length array of token ids.

    Each character is looked up in ``char_to_idx``; unknown characters map to
    the ``<unk>`` id. The result is right-padded with the ``<pad>`` id, or
    truncated, so that it always has exactly ``max_len`` entries.

    Returns:
        A 1-D ``np.int32`` array of length ``max_len``.
    """
    token_ids = [char_to_idx.get(symbol, char_to_idx['<unk>']) for symbol in smiles]
    missing = max_len - len(token_ids)
    if missing > 0:
        # Too short: fill the tail with padding ids.
        token_ids.extend([char_to_idx['<pad>']] * missing)
    else:
        # Long enough (or too long): keep the first max_len ids.
        token_ids = token_ids[:max_len]
    return np.array(token_ids, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """Return a boolean tensor that is True wherever ``token_ids`` equals ``pad_token_id``."""
    is_padding = token_ids == pad_token_id
    return tf.cast(is_padding, tf.bool)


# --- SMILES to TensorFlow Graph Conversion ---
def atom_to_feature_vector(atom):
    """Describe one atom as a float32 vector of five numeric features.

    The features, in order, are: atomic number, degree, hybridization (as an
    integer), aromaticity flag (0/1) and formal charge, each read from the
    corresponding getter on ``atom``.
    """
    descriptor = [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        int(atom.GetHybridization()),
        int(atom.GetIsAromatic()),
        atom.GetFormalCharge(),
    ]
    return np.array(descriptor, dtype=np.float32)

# Recompute NUM_ATOM_FEATURES from the featurizer itself by measuring the vector
# it returns for a probe atom (atomic number 6). This overrides the value set in
# the configuration cell, so the two must agree for the TFRecord reshape to work.
NUM_ATOM_FEATURES = len(atom_to_feature_vector(Chem.Atom(6))) 

def smiles_to_tf_graph(smiles_string):
    """Convert a SMILES string into node-feature and edge-index arrays.

    Args:
        smiles_string: SMILES text to parse with RDKit.

    Returns:
        A tuple ``(node_features, edge_indices, num_nodes, num_edges)``:
        ``node_features`` is a float32 array of shape
        ``(num_nodes, NUM_ATOM_FEATURES)``; ``edge_indices`` is an int32 array
        of shape ``(num_edges, 2)`` holding each bond in both directions;
        ``num_nodes`` and ``num_edges`` are Python ints. When the string cannot
        be parsed or the molecule has no atoms, both arrays have zero rows and
        both counts are 0.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    atom_vectors = [] if mol is None else [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not atom_vectors:
        # Unparseable SMILES or a molecule without atoms: return an empty graph.
        return (np.zeros((0, NUM_ATOM_FEATURES), dtype=np.float32),
                np.zeros((0, 2), dtype=np.int32),
                0,
                0)

    node_features = np.array(atom_vectors, dtype=np.float32)
    num_nodes = len(node_features)

    # Store every bond twice (i->j and j->i) so message passing is symmetric.
    directed_edges = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        directed_edges.append([begin, end])
        directed_edges.append([end, begin])

    # reshape(-1, 2) keeps the (0, 2) shape for molecules without bonds
    # (e.g. a single atom), where np.array([]) alone would be 1-D.
    edge_indices = np.array(directed_edges, dtype=np.int32).reshape(-1, 2)

    return node_features, edge_indices, num_nodes, len(edge_indices)


# --- TFRecord Parsing Function ---
def parse_tfrecord_example_pretraining(example_proto):
    """Parse one serialized example into the six tensors the model consumes.

    Array-valued fields are stored as raw byte strings and are decoded and
    reshaped here; the two counts are stored as int64 scalars and cast to int32.

    Returns:
        ``(node_feat_padded, edge_idx, num_nodes, num_edges, token_ids,
        smiles_mask)`` with shapes ``[MAX_NODES, NUM_ATOM_FEATURES]``,
        ``[num_edges_in_record, 2]``, ``[]``, ``[]``, ``[MAX_SMILES_LEN]`` and
        ``[MAX_SMILES_LEN]`` respectively.
    """
    raw_bytes_field = tf.io.FixedLenFeature([], tf.string)
    int_scalar_field = tf.io.FixedLenFeature([], tf.int64)
    schema = {
        'node_feat_padded': raw_bytes_field,
        'edge_idx': raw_bytes_field,
        'num_nodes': int_scalar_field,
        'num_edges': int_scalar_field,
        'token_ids': raw_bytes_field,
        'smiles_mask': raw_bytes_field,
    }
    record = tf.io.parse_single_example(example_proto, schema)

    def _decode(key, dtype, shape):
        # Raw bytes -> flat tensor of `dtype` -> requested shape.
        return tf.reshape(tf.io.decode_raw(record[key], dtype), shape)

    node_feat_padded = _decode('node_feat_padded', tf.float32, [MAX_NODES, NUM_ATOM_FEATURES])
    # The number of edges differs per molecule, so the first dimension is inferred.
    edge_idx = _decode('edge_idx', tf.int32, [-1, 2])
    num_nodes = tf.cast(record['num_nodes'], tf.int32)
    num_edges = tf.cast(record['num_edges'], tf.int32)
    token_ids = _decode('token_ids', tf.int32, [MAX_SMILES_LEN])
    smiles_mask = _decode('smiles_mask', tf.bool, [MAX_SMILES_LEN])

    return (node_feat_padded, edge_idx, num_nodes, num_edges, token_ids, smiles_mask)

In [6]:
print("Building vocabulary for pre-training data...")
# The vocabulary is built from a hand-written list of SMILES symbols rather than
# from the dataset. build_smiles_vocab splits every entry into characters, so
# "Cl" and "Br" contribute single characters, not two-letter tokens.
# NOTE(review): token_ids are read pre-computed from the TFRecords, so this
# vocabulary presumably has to match the one used during preprocessing - verify
# against the preprocessing script.
dummy_smiles_for_vocab_build = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "c", "n", "=", "#", "(", ")", "[", "]", "@", "+", "-", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "H", "B", "b", "K", "k", "L", "l", "M", "m", "R", "r", "X", "x", "Y", "y", "Z", "z"] 
_, char_to_idx, _ = build_smiles_vocab(dummy_smiles_for_vocab_build)
VOCAB_SIZE = len(char_to_idx) 
print(f"Vocabulary built with {VOCAB_SIZE} tokens for pre-training.")


# %%
# Per-element target shapes and fill values handed to `padded_batch` in
# load_tfrecord_pretraining_dataset. The order matches the tuple returned by
# parse_tfrecord_example_pretraining. Only the edge list has a variable
# dimension; every other component already has a fixed shape.
padded_shapes = (
    tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), # node_features
    tf.TensorShape([None, 2]),                     # edge_indices (padded to the longest edge list in the batch)
    tf.TensorShape([]),                           # num_nodes
    tf.TensorShape([]),                           # num_edges
    tf.TensorShape([MAX_SMILES_LEN]),              # token_ids
    tf.TensorShape([MAX_SMILES_LEN])               # smiles_mask
)
padding_values = (
    tf.constant(0.0, dtype=tf.float32),                 # node_features
    tf.constant(0, dtype=tf.int32),                     # edge_indices
    tf.constant(0, dtype=tf.int32),                     # num_nodes
    tf.constant(0, dtype=tf.int32),                     # num_edges
    tf.constant(char_to_idx['<pad>'], dtype=tf.int32),  # token_ids
    tf.constant(False, dtype=tf.bool)                   # smiles_mask
)

# Function to load pre-processed TFRecord dataset
def load_tfrecord_pretraining_dataset(tfrecord_dir, batch_size, shuffle_buffer_size=100000):
    """Build the batched, shuffled and endlessly repeating pre-training dataset.

    Args:
        tfrecord_dir: Directory searched (non-recursively) for ``*.tfrecord`` files.
        batch_size: Number of examples per batch; incomplete final batches are dropped.
        shuffle_buffer_size: Size of the example-level shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding 6-tuples as produced by
        ``parse_tfrecord_example_pretraining``, batched with the module-level
        ``padded_shapes`` / ``padding_values``. The dataset repeats forever, so
        callers must bound how many batches they consume per epoch.

    Raises:
        FileNotFoundError: If no TFRecord file is found in ``tfrecord_dir``.
    """
    # List all TFRecord files
    tfrecord_files = tf.io.gfile.glob(os.path.join(tfrecord_dir, '*.tfrecord'))
    print(f"Found {len(tfrecord_files)} TFRecord shards in {tfrecord_dir}")
    if not tfrecord_files:
        raise FileNotFoundError(f"No TFRecord files found in {tfrecord_dir}. Please run preprocess_pubchem.py first.")

    # Create dataset from TFRecord files
    dataset = tf.data.TFRecordDataset(tfrecord_files)

    # Element shapes and dtypes are fixed by the reshape/cast calls inside the
    # parsing function, so no separate signature has to be supplied to `map`.
    dataset = dataset.map(parse_tfrecord_example_pretraining,
                          num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Only the edge list varies in length; it is padded to the longest one in the batch.
    dataset = dataset.padded_batch(batch_size,
                                   padded_shapes=padded_shapes,    # Use the globally defined padded_shapes
                                   padding_values=padding_values,  # Use the globally defined padding_values
                                   drop_remainder=True)
    dataset = dataset.repeat()  # Repeat indefinitely for pre-training
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

# Load the pre-processed TFRecord dataset
dataset = load_tfrecord_pretraining_dataset(PUBCHEM_TFRECORDS_DIR, GLOBAL_BATCH_SIZE)


Building vocabulary for pre-training data...
Built vocabulary of size: 49
Vocabulary built with 49 tokens for pre-training.
Found 100 TFRecord shards in /kaggle/input/pubchem-tfrecords-1m/pubchem_tfrecords_1M


In [7]:
# GIN Layer
class GINLayer(layers.Layer):
    """One GIN-style message-passing step over a batch of graphs.

    Each node's new state is ``MLP((1 + epsilon) * h_v + sum of neighbour
    states)``, where ``epsilon`` is a trainable scalar initialised to 0 and the
    neighbour sum is computed with a sparse adjacency matrix built from the
    edge list.
    """

    def __init__(self, output_dim, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        # Two-layer MLP applied to the combined node state.
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(
            name='epsilon',
            shape=(),
            initializer=keras.initializers.Constant(0.0),
            trainable=True,
        )
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        """Apply the layer to ``(node_features, edge_indices, num_nodes)``.

        ``node_features`` holds one row per node of the whole batch and
        ``edge_indices`` holds ``[row, col]`` pairs indexing into those rows.
        The third element is accepted for interface symmetry but not used.
        """
        node_states, edges, _unused_num_nodes = inputs

        total_nodes = tf.shape(node_states)[0]
        # Unweighted adjacency: every listed edge contributes 1.0.
        adjacency = tf.sparse.SparseTensor(
            indices=tf.cast(edges, tf.int64),
            values=tf.ones(tf.shape(edges)[0], dtype=tf.float32),
            dense_shape=tf.cast([total_nodes, total_nodes], dtype=tf.int64),
        )
        aggregated = tf.sparse.sparse_dense_matmul(adjacency, node_states)

        updated = self.mlp((1 + self.epsilon) * node_states + aggregated)
        if self.activation is None:
            return updated
        return self.activation(updated)

    def compute_output_shape(self, input_shape):
        node_count = input_shape[0][0]
        return (node_count, self.output_dim)

# GIN Encoder
class GINEncoder(keras.Model):
    """Graph encoder: input MLP, a stack of GIN layers, then masked sum pooling.

    Expects node features flattened to ``(batch * MAX_NODES, features)`` and
    returns one ``hidden_dim``-sized vector per graph. ``num_node_features`` is
    accepted for configuration symmetry but not used by the implementation.

    NOTE(review): batch normalisation is applied to every row, including the
    rows that are padding; padding is only masked out at pooling time.
    Confirm this is intended.
    """

    def __init__(self, num_layers, hidden_dim, num_node_features, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.initial_mlp = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(hidden_dim)
        ])
        self.gin_layers = []
        self.bns = []
        last_index = num_layers - 1
        for layer_index in range(num_layers):
            # Every GIN layer except the last one ends with a ReLU.
            layer_activation = None if layer_index == last_index else 'relu'
            self.gin_layers.append(GINLayer(hidden_dim, activation=layer_activation))
            self.bns.append(layers.BatchNormalization())

    def call(self, inputs):
        """Encode ``(node_features, edge_indices, num_nodes)`` into graph embeddings."""
        node_features, edge_indices, num_nodes = inputs

        hidden = self.initial_mlp(node_features)
        for gin_layer, batch_norm in zip(self.gin_layers, self.bns):
            hidden = gin_layer((hidden, edge_indices, num_nodes))
            hidden = batch_norm(hidden)

        # Un-flatten to (graphs, MAX_NODES, hidden_dim) and sum only the real nodes.
        graphs_in_batch = tf.shape(node_features)[0] // MAX_NODES
        per_graph = tf.reshape(hidden, (graphs_in_batch, MAX_NODES, self.hidden_dim))
        real_node_mask = tf.expand_dims(
            tf.sequence_mask(num_nodes, maxlen=MAX_NODES, dtype=tf.float32), axis=-1)
        return tf.reduce_sum(per_graph * real_node_mask, axis=1)

# Transformer Encoder
class TransformerEncoder(keras.Model):
    """SMILES encoder: token + learned positional embeddings, a stack of
    post-norm Transformer blocks, then a padding-aware mean pool.

    Args:
        vocab_size: Number of distinct token ids.
        embed_dim: Width of the token embeddings and of the output vector.
        num_heads: Attention heads per block (``key_dim = embed_dim // num_heads``).
        num_layers: Number of Transformer blocks.
        max_seq_len: Longest sequence the positional embedding table supports.
        dropout_rate: Dropout applied inside the attention layers.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )
        # Each block is [attention, norm, feed-forward in, feed-forward out, norm].
        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False, mask=None):
        """Encode ``(token_ids, padding_mask)`` into one vector per sequence.

        ``padding_mask`` is True at padded positions. ``mask`` is accepted for
        Keras signature compatibility and is not used.
        """
        token_ids, padding_mask_bool = inputs
        padding_mask_bool = tf.cast(padding_mask_bool, tf.bool)

        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]

        # Keras MultiHeadAttention expects a mask in which 1/True marks positions
        # that MAY be attended to and 0/False marks positions to ignore. An
        # additive (0 / -1e9) mask is therefore wrong here: its zeros at the real
        # tokens would be read as "masked". Shape (batch, 1, seq) broadcasts over
        # the query axis so no query attends to padded keys.
        attention_mask = tf.expand_dims(tf.logical_not(padding_mask_bool), axis=1)

        for attention, norm1, ff_dense1, ff_dense2, norm2 in self.encoder_layers:
            attn_output = attention(x, x, attention_mask=attention_mask, training=training)
            x = norm1(x + attn_output)
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output)

        # Mean-pool over the non-padded positions only.
        expanded_padding_mask = tf.cast(tf.expand_dims(padding_mask_bool, axis=-1), dtype=x.dtype)
        non_padded_mask = 1.0 - expanded_padding_mask
        x_masked = x * non_padded_mask
        sum_embeddings = tf.reduce_sum(x_masked, axis=1)
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)
        # The small constant guards against division by zero for all-padding rows.
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9)

        return self.final_norm(smiles_embedding)

# Projection Head
class ProjectionHead(keras.Model):
    """Two-layer MLP (Dense+ReLU, then Dense) mapping an encoder output to the
    shared ``output_dim``-sized space.

    ``input_dim`` is accepted for documentation purposes only; the layers infer
    their input width on first use.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, **kwargs):
        super().__init__(**kwargs)
        hidden_layer = layers.Dense(hidden_dim, activation='relu')
        output_layer = layers.Dense(output_dim)
        self.net = keras.Sequential([hidden_layer, output_layer])

    def call(self, x):
        projected = self.net(x)
        return projected

# GRASP Model (Dual-Brain Architecture)
class GRASPModel(keras.Model):
    """Dual-encoder model pairing a GIN graph encoder with a SMILES transformer.

    Each encoder output is passed through its own projection head and
    L2-normalised, so the two returned tensors can be compared with a dot
    product.
    """

    def __init__(self, gin_config, transformer_config, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.gin_encoder = GINEncoder(**gin_config)
        self.transformer_encoder = TransformerEncoder(**transformer_config)

        self.graph_projection_head = ProjectionHead(gin_config['hidden_dim'], projection_dim)
        self.smiles_projection_head = ProjectionHead(transformer_config['embed_dim'], projection_dim)

    def call(self, inputs, training=False):
        """Return ``(graph_embeddings, smiles_embeddings)`` for one batch.

        ``inputs`` is the 6-tuple produced by the dataset: padded node features,
        padded edge indices, node counts, edge counts, token ids and the SMILES
        padding mask.
        """
        (node_features_padded, edge_indices_padded, num_nodes,
         num_edges, token_ids, smiles_mask) = inputs

        # Flatten (batch, MAX_NODES, features) into one long node table.
        feature_dim = tf.shape(node_features_padded)[2]
        node_table = tf.reshape(node_features_padded, (-1, feature_dim))

        # Keep only the real (non-padding) edges of every graph ...
        max_edges = tf.shape(edge_indices_padded)[1]
        is_real_edge = tf.sequence_mask(num_edges, maxlen=max_edges, dtype=tf.bool)
        local_edges = tf.cast(tf.boolean_mask(edge_indices_padded, is_real_edge), dtype=tf.int32)
        # ... and shift their node ids by the owning graph's offset in the node
        # table, i.e. graph_index * MAX_NODES.
        graph_of_edge = tf.cast(tf.where(is_real_edge)[:, 0], dtype=tf.int32)
        edge_offsets = tf.expand_dims(graph_of_edge * MAX_NODES, axis=-1)
        global_edges = local_edges + edge_offsets

        # Encode both views of the molecule.
        graph_raw = self.gin_encoder((node_table, global_edges, num_nodes), training=training)
        smiles_raw = self.transformer_encoder((token_ids, smiles_mask), training=training)

        # Project into the shared space, then L2-normalise for the InfoNCE loss.
        graph_projected = tf.linalg.normalize(self.graph_projection_head(graph_raw), axis=1)[0]
        smiles_projected = tf.linalg.normalize(self.smiles_projection_head(smiles_raw), axis=1)[0]

        return graph_projected, smiles_projected

In [8]:
class InfoNCELoss(keras.losses.Loss):
    """Symmetric InfoNCE loss over a batch of paired embeddings.

    Row ``i`` of the first input and row ``i`` of the second input form the
    positive pair; every other row in the batch acts as a negative.
    """

    def __init__(self, temperature=0.07, name='info_nce_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.temperature = temperature

    @tf.function
    def call(self, graph_embeddings, smiles_embeddings):
        """Return the scalar mean of the graph->SMILES and SMILES->graph losses."""
        similarity = tf.matmul(graph_embeddings, smiles_embeddings, transpose_b=True)
        logits = similarity / self.temperature

        # The matching pair sits on the diagonal, so the targets are an identity matrix.
        targets = tf.eye(tf.shape(logits)[0])
        graph_to_smiles = tf.keras.losses.categorical_crossentropy(
            targets, logits, from_logits=True)
        smiles_to_graph = tf.keras.losses.categorical_crossentropy(
            targets, tf.transpose(logits), from_logits=True)

        return tf.reduce_mean((graph_to_smiles + smiles_to_graph) / 2)

In [None]:
# Build the model, loss and optimizer inside the strategy scope, then run the
# contrastive pre-training loop with per-epoch checkpointing.
with strategy.scope(): # Model and optimizer instantiation within strategy scope
    gin_config = {
        'num_layers': NUM_GIN_LAYERS,
        'hidden_dim': HIDDEN_DIM_GIN,
        'num_node_features': NUM_ATOM_FEATURES
    }

    transformer_config = {
        'vocab_size': VOCAB_SIZE,
        'embed_dim': EMBED_DIM_TRANSFORMER,
        'num_heads': NUM_TRANSFORMER_HEADS,
        'num_layers': NUM_TRANSFORMER_LAYERS,
        'max_seq_len': MAX_SMILES_LEN
    }

    # Create checkpoint directory if it doesn't exist
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    model = GRASPModel(gin_config, transformer_config, PROJECTION_DIM)
    info_nce_loss = InfoNCELoss(temperature=TEMPERATURE)
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

    # Load checkpoint if resuming training
    if RESUME_TRAINING:
        try:
            print(f"Attempting to resume training from {CHECKPOINT_DIR}/gin_encoder_best and {CHECKPOINT_DIR}/transformer_encoder_best")
            loaded_gin_encoder = tf.saved_model.load(os.path.join(CHECKPOINT_DIR, 'gin_encoder_best'))
            loaded_transformer_encoder = tf.saved_model.load(os.path.join(CHECKPOINT_DIR, 'transformer_encoder_best'))

            # NOTE(review): tf.saved_model.load returns a generic restored object,
            # not a Keras model. Confirm that the swapped-in encoders still accept
            # `training=` and that their variables show up in
            # model.trainable_variables; otherwise resumed training will not
            # update them. The projection heads are not restored either.
            model.gin_encoder = loaded_gin_encoder
            model.transformer_encoder = loaded_transformer_encoder
            print("Model loaded successfully for resuming.")
            # Note: Optimizer state is NOT restored here. Learning rate schedule will restart.
        except Exception as e:
            print(f"Could not load checkpoint for resuming: {e}. Starting training from scratch.")
            RESUME_TRAINING = False # Reset flag if loading fails

    # The train_step function performs a single gradient update.
    @tf.function(reduce_retracing=True)
    def train_step(inputs):
        with tf.GradientTape() as tape:
            graph_embeddings, smiles_embeddings = model(inputs, training=True)
            loss = info_nce_loss(graph_embeddings, smiles_embeddings)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

# Number of batches that make up one epoch. The sample count is an estimate
# based on the 1M-sample target of the preprocessing run; replace it with the
# exact count when known.
total_samples_in_tfrecords = 1_000_000 # This should be the actual count from your preprocessing
steps_per_epoch_tqdm = total_samples_in_tfrecords // GLOBAL_BATCH_SIZE
steps_per_epoch_tqdm = max(1, steps_per_epoch_tqdm)

print(f"\nStarting pre-training for {EPOCHS} epochs...")
print(f"Dataset has {steps_per_epoch_tqdm} batches per epoch.")

# Despite the name, this tracks the best average *training* loss per epoch;
# there is no separate validation split in this loop.
best_val_loss = float('inf') # Initialize best loss for saving best model

for epoch in range(START_EPOCH, EPOCHS): # Use START_EPOCH for resuming
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    total_loss = 0.0
    num_batches = 0

    # The dataset repeats indefinitely, so iterating it directly would never
    # finish an epoch (and never reach the checkpointing code below). Bound each
    # epoch to a fixed number of batches instead.
    epoch_batches = dataset.take(steps_per_epoch_tqdm)
    epoch_iterator = tqdm(epoch_batches, desc=f"Epoch {epoch + 1} Training", total=steps_per_epoch_tqdm, leave=True)

    for batch_inputs in epoch_iterator:
        per_replica_losses = strategy.run(train_step, args=(batch_inputs,))
        # Reduce sum across replicas (will be just the loss itself for OneDeviceStrategy)
        batch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

        # Accumulate as a plain float so the running total is not a tensor.
        batch_loss_value = float(batch_loss.numpy())
        total_loss += batch_loss_value
        num_batches += 1

        epoch_iterator.set_postfix(loss=f"{batch_loss_value:.4f}")

    if num_batches > 0:
        avg_loss = total_loss / num_batches
        print(f"\nEpoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}")

        # --- Checkpointing Logic ---
        current_gin_path = os.path.join(CHECKPOINT_DIR, f'gin_encoder_epoch_{epoch+1}')
        current_transformer_path = os.path.join(CHECKPOINT_DIR, f'transformer_encoder_epoch_{epoch+1}')

        tf.saved_model.save(model.gin_encoder, current_gin_path)
        tf.saved_model.save(model.transformer_encoder, current_transformer_path)
        print(f"Epoch {epoch+1} encoders saved to {CHECKPOINT_DIR}")

        if avg_loss < best_val_loss:
            best_val_loss = avg_loss
            best_gin_path = os.path.join(CHECKPOINT_DIR, 'gin_encoder_best')
            best_transformer_path = os.path.join(CHECKPOINT_DIR, 'transformer_encoder_best')

            tf.saved_model.save(model.gin_encoder, best_gin_path)
            tf.saved_model.save(model.transformer_encoder, best_transformer_path)
            print(f"New best model saved at epoch {epoch+1} with loss {best_val_loss:.4f}")

    else:
        print(f"\nEpoch {epoch + 1} finished. No batches processed (check dataset size/filtering).\n") # Added newline for clarity

print("\nPre-training complete!")


Starting pre-training for 5 epochs...
Dataset has 15625 batches per epoch.

Epoch 1/5


Epoch 1 Training:  26%|██▋       | 4138/15625 [09:52<25:57,  7.38it/s, loss=0.2015] 

In [None]:
# Save both encoders in the state they have after the last training step.
# This can duplicate the "best" checkpoint when the last epoch was also the best
# one, but guarantees the final state is on disk either way. The paths are
# relative to the current working directory, not to CHECKPOINT_DIR.
tf.saved_model.save(model.gin_encoder, 'gin_encoder_final')
tf.saved_model.save(model.transformer_encoder, 'transformer_encoder_final')
print("Final Encoders saved.")

# Optionally, export the whole GRASP model (encoders plus projection heads) as well.
model.export('grasp_pretrained_model_tf_savedmodel')
print("Full Model saved to 'grasp_pretrained_model_tf_savedmodel' directory.")

In [None]:
print("\n--- Starting Post-Pretraining Qualitative Evaluation ---")

# The contrastive loss is computed on the projection-head outputs, so graph and
# SMILES vectors are only directly comparable after projection. The heads are
# available only when the trained model is still in memory.
graph_projection_head = None
smiles_projection_head = None

try:
    if 'model' not in locals() or not hasattr(model, 'gin_encoder') or not hasattr(model, 'transformer_encoder'):
        print("Loading encoders from disk...")
        loaded_gin_encoder = tf.saved_model.load('gin_encoder_final') # Load final saved encoders
        loaded_transformer_encoder = tf.saved_model.load('transformer_encoder_final')
    else:
        print("Using encoders directly from trained model in memory.")
        loaded_gin_encoder = model.gin_encoder
        loaded_transformer_encoder = model.transformer_encoder
        graph_projection_head = getattr(model, 'graph_projection_head', None)
        smiles_projection_head = getattr(model, 'smiles_projection_head', None)

except Exception as e:
    print(f"Error loading encoders: {e}. Make sure they were saved correctly and paths are valid.")
    print("Skipping qualitative evaluation.")
    # Equivalent to sys.exit(): stops the rest of this cell from running.
    raise SystemExit()

if graph_projection_head is None or smiles_projection_head is None:
    print("Warning: projection heads are not available; similarities are computed on raw "
          "encoder outputs, which the contrastive loss did not align directly.")


# Helper function to get embeddings for a single SMILES string
def get_single_molecule_embeddings(smiles_string, gin_encoder, transformer_encoder,
                                   graph_projection_head=None, smiles_projection_head=None):
    """Embed one molecule with both encoders.

    Args:
        smiles_string: SMILES text of the molecule.
        gin_encoder: Callable taking ``(node_features, edge_indices, num_nodes)``.
        transformer_encoder: Callable taking ``(token_ids, padding_mask)``.
        graph_projection_head: Optional callable applied to the graph embedding
            before normalisation.
        smiles_projection_head: Optional callable applied to the SMILES
            embedding before normalisation.

    Returns:
        ``(graph_embedding, smiles_embedding)``, each L2-normalised with a
        leading batch dimension of 1, or ``(None, None)`` when the molecule
        cannot be featurized or has more than ``MAX_NODES`` atoms.
    """
    token_ids = tokenize_smiles(smiles_string, char_to_idx, MAX_SMILES_LEN)
    smiles_mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, _ = smiles_to_tf_graph(smiles_string)

    if num_nodes == 0:
        print(f"Warning: Could not featurize SMILES '{smiles_string}'. Skipping.")
        return None, None
    if num_nodes > MAX_NODES:
        # tf.pad would be asked for a negative amount of padding.
        print(f"Warning: SMILES '{smiles_string}' has {num_nodes} atoms (> MAX_NODES={MAX_NODES}). Skipping.")
        return None, None

    # Mimic a batch of one graph: pad the node table to MAX_NODES rows. With a
    # single graph the node offset is 0, so edge indices are used as they are.
    padded_node_features = tf.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    node_features_for_gin = tf.reshape(padded_node_features, (1 * MAX_NODES, NUM_ATOM_FEATURES))
    edge_indices_for_gin = tf.cast(edge_indices, dtype=tf.int32)
    num_nodes_for_gin = tf.constant([num_nodes], dtype=tf.int32)

    token_ids_for_transformer = tf.expand_dims(token_ids, axis=0)
    smiles_mask_for_transformer = tf.expand_dims(smiles_mask, axis=0)

    graph_embedding = gin_encoder((node_features_for_gin, edge_indices_for_gin, num_nodes_for_gin), training=False)
    smiles_embedding = transformer_encoder((token_ids_for_transformer, smiles_mask_for_transformer), training=False)

    # Move both vectors into the shared space the loss was trained on, when possible.
    if graph_projection_head is not None:
        graph_embedding = graph_projection_head(graph_embedding)
    if smiles_projection_head is not None:
        smiles_embedding = smiles_projection_head(smiles_embedding)

    graph_embedding_normalized = tf.linalg.normalize(graph_embedding, axis=1)[0]
    smiles_embedding_normalized = tf.linalg.normalize(smiles_embedding, axis=1)[0]

    return graph_embedding_normalized, smiles_embedding_normalized


# --- Test Cases ---
test_smiles = [
    "CCO",      # Ethanol (small, simple alcohol)
    "c1ccccc1", # Benzene (aromatic ring)
    "O=C(Cl)c1ccccc1", # Benzoyl chloride (more complex, functional group)
    "CC(=O)Oc1ccccc1C(=O)O", # Aspirin (larger, common drug)
    "C" # Methane (single atom, check edge case)
]

all_graph_embeddings = []
all_smiles_embeddings = []
valid_smiles_for_eval = []

for smiles in test_smiles:
    g_embed, s_embed = get_single_molecule_embeddings(
        smiles, loaded_gin_encoder, loaded_transformer_encoder,
        graph_projection_head=graph_projection_head,
        smiles_projection_head=smiles_projection_head)
    if g_embed is not None and s_embed is not None:
        all_graph_embeddings.append(g_embed)
        all_smiles_embeddings.append(s_embed)
        valid_smiles_for_eval.append(smiles)

if not valid_smiles_for_eval:
    print("No valid SMILES were processed for qualitative evaluation.")
else:
    all_graph_embeddings_tensor = tf.concat(all_graph_embeddings, axis=0)
    all_smiles_embeddings_tensor = tf.concat(all_smiles_embeddings, axis=0)

    print("\n--- Cosine Similarity Matrix (Graph vs. SMILES) ---")
    # Both sides are unit-length, so the dot product is the cosine similarity.
    similarity_matrix = tf.matmul(all_graph_embeddings_tensor, all_smiles_embeddings_tensor, transpose_b=True).numpy()

    print("Rows: Graph Embeddings of [SMILES]\nCols: SMILES Embeddings of [SMILES]")
    print(" " * 10 + "".join([f"{s[:8]:<10}" for s in valid_smiles_for_eval]))
    print("-" * (10 + len(valid_smiles_for_eval) * 10))

    for i, smiles_i in enumerate(valid_smiles_for_eval):
        row_str = f"{smiles_i[:8]:<10}"
        for j in range(len(valid_smiles_for_eval)):
            row_str += f"{similarity_matrix[i, j]:<10.4f}"
        print(row_str)
    print("-" * (10 + len(valid_smiles_for_eval) * 10))

    print("\n--- Key Observations ---")
    print("Expected: High values on the diagonal (positive pairs), low values off-diagonal (negative pairs).\n")
    print("This indicates that the pre-trained model successfully learns to align graph and SMILES representations of the same molecule.")
