In [51]:
!pip install deepchem --quiet
!pip install rdkit --quiet
!pip install tqdm --quiet

In [52]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem
import deepchem as dc 
from tqdm import tqdm # For progress bars

# Suppress non-critical RDKit warnings
from rdkit import rdBase
rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

In [53]:
# ## 2. GPU Initialization
# Sets up the TensorFlow strategy for Kaggle GPU.

# %%
print("Checking for GPU support...")
try:
    # Try to detect GPU (NVIDIA T4/P100 on Kaggle)
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        print(f"Found {len(gpus)} GPU(s) available.")
        # Use the default strategy, which will pick up the single GPU
        strategy = tf.distribute.get_strategy() 
        print(f"Using GPU strategy on: {gpus[0].name}")
    else:
        print("No GPU device found, defaulting to CPU.")
        strategy = tf.distribute.get_strategy() # Default to CPU strategy
except Exception as e:
    print(f"Error during GPU detection/setup: {e}")
    print("Defaulting to CPU.")
    strategy = tf.distribute.get_strategy()

print(f"Number of accelerators: {strategy.num_replicas_in_sync}")

Checking for GPU support...
No GPU device found, defaulting to CPU.
Number of accelerators: 1


In [54]:
# ## 3. Global Configuration
# Define global hyperparameters. These must be consistent with your pre-training notebook.

# %%
# --- Global Configuration (MUST match your pre-training notebook) ---
MAX_SMILES_LEN = 256 
MAX_NODES = 419 # This should be the MAX_NODES used during pre-training
NUM_ATOM_FEATURES = 5 # As defined by your atom_to_feature_vector
PROJECTION_DIM = 128
HIDDEN_DIM_GIN = 256 
EMBED_DIM_TRANSFORMER = 256 

BATCH_SIZE_PER_REPLICA = 64 # Batch size per GPU. Adjust based on T4/P100 memory.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync # For 1 device, this is BATCH_SIZE_PER_REPLICA

# --- Paths for Pre-trained Encoders ---
# IMPORTANT: Update these paths to your new Kaggle dataset containing the saved encoders.
# Example: '/kaggle/input/your-pretrained-encoders-dataset-name/gin_encoder_best'
PRETRAINED_GIN_ENCODER_PATH = '/kaggle/input/grasp-saved-model/pretraining_checkpoints/gin_encoder_best' 
PRETRAINED_TRANSFORMER_ENCODER_PATH = '/kaggle/input/grasp-saved-model/pretraining_checkpoints/gin_transformer_best' 

# --- Path to Preprocessed MoleculeNet TFRecords ---
# If you ran preprocess_moleculenet.py on Kaggle and saved outputs, update this.
# If you haven't preprocessed MoleculeNet to TFRecords yet, this path won't be used,
# and the eager featurization will run.
MOLECULE_NET_TFRECORDS_DIR = '/kaggle/input/moleculenet-tfrecords/moleculenet_tfrecords' 
# Set to None if you want to use eager featurization from DeepChem directly:
# MOLECULE_NET_TFRECORDS_DIR = None

In [60]:
# ## 4. Helper Functions (Data Processing)
# These functions are crucial for featurizing SMILES to graph and tokenized forms.

# %%
# --- SMILES Tokenization ---
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    all_chars = set()
    for smiles in smiles_list:
        for char in smiles:
            all_chars.add(char)
    vocab = sorted(list(all_chars))
    vocab = ['<pad>', '<unk>', '<cls>', '<eos>'] + vocab
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    idx_to_char = {i: char for i, char in enumerate(vocab)}
    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

def tokenize_smiles(smiles, char_to_idx, max_len):
    tokens = list(smiles)
    indexed_tokens = [char_to_idx.get(char, char_to_idx['<unk>']) for char in tokens]
    if len(indexed_tokens) < max_len:
        padded_tokens = indexed_tokens + [char_to_idx['<pad>']] * (max_len - len(indexed_tokens))
    else:
        padded_tokens = indexed_tokens[:max_len]
    return np.array(padded_tokens, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    return tf.cast(token_ids == pad_token_id, tf.bool)


# --- SMILES to TensorFlow Graph Conversion ---
def atom_to_feature_vector(atom):
    features = []
    features.append(atom.GetAtomicNum())
    features.append(atom.GetDegree())
    features.append(int(atom.GetHybridization()))
    features.append(int(atom.GetIsAromatic()))
    features.append(atom.GetFormalCharge())
    return np.array(features, dtype=np.float32)

def smiles_to_tf_graph(smiles_string):
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return (np.zeros((0, NUM_ATOM_FEATURES), dtype=np.float32),
                np.zeros((0, 2), dtype=np.int32),
                0,
                0)

    node_features = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not node_features:
        return (np.zeros((0, NUM_ATOM_FEATURES), dtype=np.float32),
                np.zeros((0, 2), dtype=np.int32),
                0,
                0)
    node_features = np.array(node_features, dtype=np.float32)
    num_nodes = len(node_features)

    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])

    num_edges = len(edge_indices)
    if num_edges == 0:
        if num_nodes > 0:
            edge_indices_final = np.empty((0, 2), dtype=np.int32)
            num_edges_final = 0
        else:
            return (np.zeros((0, NUM_ATOM_FEATURES), dtype=np.float32),
                    np.zeros((0, 2), dtype=np.int32),
                    0,
                    0)
    else:
        edge_indices_final = np.array(edge_indices, dtype=np.int32)
        num_edges_final = len(edge_indices_final)

    return node_features, edge_indices_final, num_nodes, num_edges_final

def featurize_smiles_and_graph_with_label(smiles_input, label): 
    # --- CRITICAL FIX: Ensure smiles_input is a string, converting from Mol if necessary ---
    # This ensures that `smiles_string` is always a Python string for `tokenize_smiles` and `smiles_to_tf_graph`.
    smiles_string = None # Initialize to None

    if isinstance(smiles_input, str):
        smiles_string = smiles_input
    elif isinstance(smiles_input, Chem.Mol): 
        try:
            smiles_string = Chem.MolToSmiles(smiles_input, canonical=True)
        except Exception:
            pass 
    # If smiles_input is None or other unexpected type, smiles_string remains None

    if smiles_string is None or len(smiles_string) == 0:
        # If the SMILES string is invalid or unconvertible, return dummy data
        dummy_node_features = np.zeros((MAX_NODES, NUM_ATOM_FEATURES), dtype=np.float32)
        dummy_edge_indices = np.zeros((0, 2), dtype=np.int32)
        dummy_num_nodes = 0
        dummy_num_edges = 0
        dummy_token_ids = np.zeros((MAX_SMILES_LEN,), dtype=np.int32) 
        dummy_mask = np.zeros((MAX_SMILES_LEN,), dtype=np.bool_) 
        return (dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, dummy_token_ids, dummy_mask, label)
    
    # Now, smiles_string is guaranteed to be a valid string
    token_ids = tokenize_smiles(smiles_string, char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles_string)

    if num_nodes == 0: 
        dummy_node_features = np.zeros((MAX_NODES, NUM_ATOM_FEATURES), dtype=np.float32)
        dummy_edge_indices = np.zeros((0, 2), dtype=np.int32)
        dummy_num_nodes = 0
        dummy_num_edges = 0
        dummy_token_ids = np.zeros((MAX_SMILES_LEN,), dtype=np.int32) 
        dummy_mask = np.zeros((MAX_SMILES_LEN,), dtype=np.bool_) 
        return (dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, dummy_token_ids, dummy_mask, label)
    
    padded_node_features = np.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    
    return (padded_node_features, edge_indices, num_nodes, num_edges, token_ids, mask, label)

# --- Global Vocabulary Building (Needed for tokenization) ---
print("Building vocabulary for MoleculeNet data...")
dummy_smiles_for_vocab_build = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "c", "n", "=", "#", "(", ")", "[", "]", "@", "+", "-", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "H", "B", "b", "K", "k", "L", "l", "M", "m", "R", "r", "X", "x", "Y", "y", "Z", "z"] 
_, char_to_idx, _ = build_smiles_vocab(dummy_smiles_for_vocab_build)
VOCAB_SIZE = len(char_to_idx) 

Building vocabulary for MoleculeNet data...
Built vocabulary of size: 49


In [61]:
# ## 6. Model Architecture (Re-define for Loading)
# These classes must be IDENTICAL to how they were defined in your 2_GRASP_Pretraining.ipynb.

# %%
# GIN Layer
class GINLayer(layers.Layer):
    def __init__(self, output_dim, activation=None, **kwargs):
        super(GINLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(name='epsilon', shape=(),
                                       initializer=keras.initializers.Constant(0.0),
                                       trainable=True)
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        node_features, edge_indices_batch, num_nodes_batch = inputs 
        edge_values = tf.ones(tf.shape(edge_indices_batch)[0], dtype=tf.float32)
        total_nodes_in_batch = tf.shape(node_features)[0]
        adj_shape = tf.cast([total_nodes_in_batch, total_nodes_in_batch], dtype=tf.int64)
        adj_sparse = tf.sparse.SparseTensor(indices=tf.cast(edge_indices_batch, tf.int64),
                                            values=edge_values,
                                            dense_shape=adj_shape)
        neighbor_sum = tf.sparse.sparse_dense_matmul(adj_sparse, node_features)
        combined_features = (1 + self.epsilon) * node_features + neighbor_sum
        output = self.mlp(combined_features)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        return input_shape[0][0], self.output_dim

# GIN Encoder
class GINEncoder(keras.Model):
    def __init__(self, num_layers, hidden_dim, num_node_features, **kwargs): 
        super(GINEncoder, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim 
        self.initial_mlp = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(hidden_dim) 
        ])
        self.gin_layers = []
        self.bns = []
        for i in range(num_layers): 
            self.gin_layers.append(GINLayer(hidden_dim, activation='relu' if i < num_layers - 1 else None))
            self.bns.append(layers.BatchNormalization())
    
    def call(self, inputs):
        node_features, edge_indices, num_nodes = inputs
        x = self.initial_mlp(node_features)
        for i in range(len(self.gin_layers)): 
            x = self.gin_layers[i]((x, edge_indices, num_nodes))
            x = self.bns[i](x)
        batch_size = tf.shape(node_features)[0] // MAX_NODES
        x_reshaped = tf.reshape(x, (batch_size, MAX_NODES, self.hidden_dim)) 
        sequence_mask = tf.sequence_mask(num_nodes, maxlen=MAX_NODES, dtype=tf.float32) 
        sequence_mask = tf.expand_dims(sequence_mask, axis=-1) 
        masked_x = x_reshaped * sequence_mask
        graph_embedding = tf.reduce_sum(masked_x, axis=1)
        return graph_embedding

# Transformer Encoder
class TransformerEncoder(keras.Model):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )
        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False, mask=None):
        token_ids, padding_mask_bool = inputs 
        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]
        attention_mask_additive = tf.cast(padding_mask_bool, dtype=tf.float32) * -1e9
        attention_mask_additive = tf.expand_dims(attention_mask_additive, axis=1) 
        
        for i, (attention, norm1, ff_dense1, ff_dense2, norm2) in enumerate(self.encoder_layers):
            attn_output = attention(x, x, attention_mask=attention_mask_additive, training=training)
            x = norm1(x + attn_output)
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output)
        
        expanded_padding_mask = tf.cast(tf.expand_dims(padding_mask_bool, axis=-1), dtype=x.dtype)
        non_padded_mask = 1.0 - expanded_padding_mask
        x_masked = x * non_padded_mask
        sum_embeddings = tf.reduce_sum(x_masked, axis=1)
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9)
        
        return self.final_norm(smiles_embedding)

# Projection Head
class ProjectionHead(keras.Model):
    def __init__(self, input_dim, output_dim, hidden_dim=256, **kwargs):
        super(ProjectionHead, self).__init__(**kwargs)
        self.net = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(output_dim)
        ])

    def call(self, x):
        return self.net(x)

# Downstream Model
class DownstreamModel(keras.Model):
    def __init__(self, pre_trained_gin_encoder, pre_trained_transformer_encoder, output_dim, task_type='classification', **kwargs):
        super(DownstreamModel, self).__init__(**kwargs)
        self.gin_encoder = pre_trained_gin_encoder
        self.transformer_encoder = pre_trained_transformer_encoder
        
        # Freeze the pre-trained encoders initially to use them as feature extractors.
        self.gin_encoder.trainable = False
        self.transformer_encoder.trainable = False

        input_to_head_dim = HIDDEN_DIM_GIN + EMBED_DIM_TRANSFORMER 

        self.classifier_head = keras.Sequential([
            layers.Dense(input_to_head_dim // 2, activation='relu'), 
            layers.Dropout(0.3), 
            layers.Dense(output_dim, activation='sigmoid' if task_type == 'classification' else 'linear')
        ])
        self.task_type = task_type

    def call(self, inputs, training=False):
        node_features_padded, edge_indices_padded, num_nodes, num_edges, token_ids, smiles_mask = inputs

        node_features_flat = tf.reshape(node_features_padded, (-1, tf.shape(node_features_padded)[2]))
        
        batch_size = tf.shape(node_features_padded)[0] 
        edge_mask = tf.sequence_mask(num_edges, maxlen=tf.shape(edge_indices_padded)[1], dtype=tf.bool)
        valid_edge_indices = tf.cast(tf.boolean_mask(edge_indices_padded, edge_mask), dtype=tf.int32) 
        batch_ids_for_edges = tf.cast(tf.where(edge_mask)[:, 0], dtype=tf.int32)
        node_offsets_for_edges = tf.range(batch_size) * MAX_NODES
        offsets = tf.gather(node_offsets_for_edges, batch_ids_for_edges)
        offsets = tf.expand_dims(offsets, axis=-1)
        global_edge_indices_filtered = valid_edge_indices + offsets
        
        graph_embeddings = self.gin_encoder((node_features_flat, global_edge_indices_filtered, num_nodes), training=training)
        smiles_embeddings = self.transformer_encoder((token_ids, smiles_mask), training=training)

        combined_embeddings = tf.concat([graph_embeddings, smiles_embeddings], axis=-1)

        output = self.classifier_head(combined_embeddings, training=training)
        return output

In [62]:
# ## 6. Load Pre-trained Encoders
# Loads the saved encoders from the pre-training phase.

# %%
with strategy.scope(): 
    pre_trained_gin_encoder = tf.saved_model.load(PRETRAINED_GIN_ENCODER_PATH)
    pre_trained_transformer_encoder = tf.saved_model.load(PRETRAINED_TRANSFORMER_ENCODER_PATH)

print("Pre-trained GIN Encoder and Transformer Encoder loaded successfully.")

Pre-trained GIN Encoder and Transformer Encoder loaded successfully.


In [63]:
# ## 7. Data Preparation for Downstream Tasks (MoleculeNet)
# This section handles loading and eager featurization of MoleculeNet datasets.

# %%
# Define the featurization function for a single sample (returns flat tuple of NumPy arrays)
def featurize_smiles_and_graph_with_label(smiles_string, label):
    # This function will be called in eager mode (Python loop)
    token_ids = tokenize_smiles(smiles_string, char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles_string)

    if num_nodes == 0:
        # Return dummy NumPy arrays with correct shapes and dtypes for filtering
        dummy_node_features = np.zeros((MAX_NODES, NUM_ATOM_FEATURES), dtype=np.float32)
        dummy_edge_indices = np.zeros((0, 2), dtype=np.int32)
        dummy_num_nodes = 0
        dummy_num_edges = 0
        dummy_token_ids = np.zeros((MAX_SMILES_LEN,), dtype=np.int32)
        dummy_mask = np.zeros((MAX_SMILES_LEN,), dtype=np.bool_)
        return (dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, dummy_token_ids, dummy_mask, label) # Include label in flat return
    
    padded_node_features = np.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    
    # Return a flat tuple of NumPy arrays (will be converted to tf.Tensor by from_tensor_slices)
    return (padded_node_features, edge_indices, num_nodes, num_edges, token_ids, mask, label)


# Function to convert DeepChem dataset to TensorFlow dataset
def dc_dataset_to_tf_dataset_for_downstream(dc_dataset, batch_size, shuffle_buffer_size=1000):
    smiles_list = dc_dataset.X.tolist() 
    labels = dc_dataset.y 

    # Handle multi-task labels if necessary (e.g., Tox21 has multiple labels per sample)
    # Ensure labels are float32 for BCELoss
    labels_dtype = tf.float32 if labels.dtype == np.bool_ or labels.dtype == np.int_ else labels.dtype
    label_shape = labels.shape[1:] if labels.ndim > 1 else ()

    print(f"Pre-featurizing {len(smiles_list)} samples for DeepChem dataset...")
    
    all_processed_inputs = []
    all_processed_labels = []
    
    for i in tqdm(range(len(smiles_list)), desc="Eager Featurization"):
        smiles_str = smiles_list[i]
        label = labels[i]
        
        # Call the featurization function directly
        # featurize_smiles_and_graph_with_label returns a flat tuple of 7 NumPy arrays
        processed_data_flat = featurize_smiles_and_graph_with_label(smiles_str, label)
            
        # Filter out invalid samples immediately (num_nodes is at index 2 in the flat tuple)
        if processed_data_flat[2] > 0: 
            # Append the flat tuple of NumPy arrays
            all_processed_inputs.append(processed_data_flat[:6]) # Append the 6 input components
            all_processed_labels.append(processed_data_flat[6])  # Append the label component
    
    print(f"Finished eager featurization. {len(all_processed_inputs)} valid samples processed.")
    
    # Create tf.data.Dataset from these already processed NumPy arrays
    # from_tensor_slices is suitable here as all_processed_inputs is a list of tuples of NumPy arrays.
    dataset = tf.data.Dataset.from_tensor_slices((all_processed_inputs, all_processed_labels))
    
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Define padded shapes for the loaded data
    padded_shapes_inputs_tf = (
        tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), # node_features
        tf.TensorShape([None, 2]),                     # edge_indices
        tf.TensorShape([]),                           # num_nodes
        tf.TensorShape([]),                           # num_edges
        tf.TensorSpec(shape=[MAX_SMILES_LEN], dtype=tf.int32).shape, # token_ids, use .shape
        tf.TensorSpec(shape=[MAX_SMILES_LEN], dtype=tf.bool).shape # smiles_mask, use .shape
    )
    padding_values_inputs_tf = (
        tf.constant(0.0, dtype=tf.float32),
        tf.constant(0, dtype=tf.int32),
        tf.constant(0, dtype=tf.int32),
        tf.constant(0, dtype=tf.int32),
        tf.constant(char_to_idx['<pad>'], dtype=tf.int32),
        tf.constant(False, dtype=tf.bool) 
    )

    # Combine input and label shapes for padded_batch
    padded_shapes_full = (padded_shapes_inputs_tf, label_shape)
    padding_values_full = (padding_values_inputs_tf, tf.constant(0, dtype=labels_dtype))
    
    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes_full, padding_values=padding_values_full, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [64]:
# ## 8. Benchmarking Loop for BBBP, Tox21, ESOL
# Executes fine-tuning and evaluation for each MoleculeNet task.

# %%
# Define the global `char_to_idx` and `VOCAB_SIZE` first.
# You need to ensure this matches the vocab built during pre-training.
# For a full setup, you would save the vocab from pre-training and load it here.
# For this example, we'll build it from a small set, assuming it's representative enough
# for the characters present in MoleculeNet, but this is a potential source of errors
# if MoleculeNet has chars not in this small set.
# Ideally, load your pre-training vocab.
print("Building vocabulary for MoleculeNet data...")
dummy_smiles_for_vocab_build = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "c", "n", "=", "#", "(", ")", "[", "]", "@", "+", "-", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "H", "B", "b", "K", "k", "L", "l", "M", "m", "R", "r", "X", "x", "Y", "y", "Z", "z"] 
_, char_to_idx, _ = build_smiles_vocab(dummy_smiles_for_vocab_build)
VOCAB_SIZE = len(char_to_idx) 


# --- BBBP (Blood-Brain Barrier Permeability) ---
print("\n--- Benchmarking BBBP ---")
bbbp_tasks, bbbp_datasets, bbbp_transformers = dc.molnet.load_bbbp(featurizer='Raw', splitter='scaffold')
train_bbbp, valid_bbbp, test_bbbp = bbbp_datasets
bbbp_output_dim = len(bbbp_tasks) 

train_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(train_bbbp, GLOBAL_BATCH_SIZE)
valid_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(valid_bbbp, GLOBAL_BATCH_SIZE)
test_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(test_bbbp, GLOBAL_BATCH_SIZE)

with strategy.scope():
    bbbp_model = DownstreamModel(pre_trained_gin_encoder, pre_trained_transformer_encoder, 
                                 output_dim=bbbp_output_dim, task_type='classification')
    bbbp_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                       loss=keras.losses.BinaryCrossentropy(from_logits=False),
                       metrics=[keras.metrics.AUC(name='auc'), keras.metrics.BinaryAccuracy(name='accuracy')])
    
    # Build the model with a dummy input shape for the first call
    dummy_input_for_build = (
        tf.zeros((GLOBAL_BATCH_SIZE, MAX_NODES, NUM_ATOM_FEATURES), dtype=tf.float32), 
        tf.zeros((GLOBAL_BATCH_SIZE, 10, 2), dtype=tf.int32), 
        tf.constant([50]*GLOBAL_BATCH_SIZE, dtype=tf.int32),  
        tf.constant([10]*GLOBAL_BATCH_SIZE, dtype=tf.int32),  
        tf.zeros((GLOBAL_BATCH_SIZE, MAX_SMILES_LEN), dtype=tf.int32),
        tf.zeros((GLOBAL_BATCH_SIZE, MAX_SMILES_LEN), dtype=tf.bool)
    )
    _ = bbbp_model(dummy_input_for_build, training=False)
    print("BBBP DownstreamModel built successfully.")

print("Fine-tuning BBBP model...")
FINE_TUNE_EPOCHS = 10 
bbbp_model.fit(train_bbbp_tf, epochs=FINE_TUNE_EPOCHS, validation_data=valid_bbbp_tf)

print("\nEvaluating BBBP model on test set...")
bbbp_results = bbbp_model.evaluate(test_bbbp_tf)
print(f"BBBP Test Loss: {bbbp_results[0]:.4f}, Test AUC: {bbbp_results[1]:.4f}, Test Accuracy: {bbbp_results[2]:.4f}")


# --- Tox21 ---
print("\n--- Benchmarking Tox21 ---")
tox21_tasks, tox21_datasets, tox21_transformers = dc.molnet.load_tox21(featurizer='Raw', splitter='scaffold')
train_tox21, valid_tox21, test_tox21 = tox21_datasets
tox21_output_dim = len(tox21_tasks) 

train_tox21_tf = dc_dataset_to_tf_dataset_for_downstream(tox21_train, GLOBAL_BATCH_SIZE)
valid_tox21_tf = dc_dataset_to_tf_dataset_for_downstream(tox21_valid, GLOBAL_BATCH_SIZE)
test_tox21_tf = dc_dataset_to_tf_dataset_for_downstream(tox21_test, GLOBAL_BATCH_SIZE)

with strategy.scope():
    tox21_model = DownstreamModel(pre_trained_gin_encoder, pre_trained_transformer_encoder, 
                                  output_dim=tox21_output_dim, task_type='classification')
    tox21_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                        loss=keras.losses.BinaryCrossentropy(from_logits=False),
                        metrics=[keras.metrics.AUC(name='auc', multi_label=True), keras.metrics.BinaryAccuracy(name='accuracy')])
    
    _ = tox21_model(dummy_input_for_build, training=False)
    print("Tox21 DownstreamModel built successfully.")

print("Fine-tuning Tox21 model...")
tox21_model.fit(train_tox21_tf, epochs=FINE_TUNE_EPOCHS, validation_data=valid_tox21_tf)

print("\nEvaluating Tox21 model on test set...")
tox21_results = tox21_model.evaluate(test_tox21_tf)
print(f"Tox21 Test Loss: {tox21_results[0]:.4f}, Test AUC: {tox21_results[1]:.4f}, Test Accuracy: {tox21_results[2]:.4f}")


# --- ESOL (Solubility Estimation) ---
print("\n--- Benchmarking ESOL ---")
esol_tasks, esol_datasets, esol_transformers = dc.molnet.load_esol(featurizer='Raw', splitter='scaffold')
train_esol, valid_esol, test_esol = esol_datasets
esol_output_dim = len(esol_tasks) 

train_esol_tf = dc_dataset_to_tf_dataset_for_downstream(esol_train, GLOBAL_BATCH_SIZE)
valid_esol_tf = dc_dataset_to_tf_dataset_for_downstream(esol_valid, GLOBAL_BATCH_SIZE)
test_esol_tf = dc_dataset_to_tf_dataset_for_downstream(esol_test, GLOBAL_BATCH_SIZE)

with strategy.scope():
    esol_model = DownstreamModel(pre_trained_gin_encoder, pre_trained_transformer_encoder, 
                                 output_dim=esol_output_dim, task_type='regression')
    esol_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
                       loss=keras.losses.MeanSquaredError(),
                       metrics=[keras.metrics.MeanAbsoluteError(name='mae')])
    
    _ = esol_model(dummy_input_for_build, training=False)
    print("ESOL DownstreamModel built successfully.")

print("Fine-tuning ESOL model...")
esol_model.fit(train_esol_tf, epochs=FINE_TUNE_EPOCHS, validation_data=valid_esol_tf)

print("\nEvaluating ESOL model on test set...")
esol_results = esol_model.evaluate(test_esol_tf)
print(f"ESOL Test Loss: {esol_results[0]:.4f}, Test MAE: {esol_results[1]:.4f}")


print("\nAll MoleculeNet benchmarking complete!")

Building vocabulary for MoleculeNet data...
Built vocabulary of size: 49

--- Benchmarking BBBP ---
Pre-featurizing 1631 samples for DeepChem dataset...


Eager Featurization:   0%|          | 0/1631 [00:00<?, ?it/s]


TypeError: 'Mol' object is not iterable