# In [None]:
# --- Imports ---
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem
import deepchem as dc # For MoleculeNet datasets

# --- 0. TPU/GPU/CPU Initialization ---
def _resolve_distribution_strategy():
    """Choose the tf.distribute strategy for this notebook.

    Tries to discover and initialize a TPU; if any step raises ``ValueError``
    the default (GPU/CPU) strategy is used instead.

    Returns:
        ``(tpu, strategy)``: the TPU cluster resolver (``None`` when falling
        back) and the strategy under which models are built and run.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Running on TPU ', resolver.cluster_spec().as_dict()['worker'])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return resolver, tf.distribute.TPUStrategy(resolver)
    except ValueError:
        # Any ValueError during TPU discovery/initialization means "use the default".
        print('Not running on TPU, defaulting to GPU/CPU.')
        return None, tf.distribute.get_strategy()  # Default to GPU or CPU strategy


tpu, strategy = _resolve_distribution_strategy()

print(f"Number of accelerators: {strategy.num_replicas_in_sync}")


# --- 1. Global Configuration (Must match pre-training notebook) ---
# These values MUST be consistent with your 2_GRASP_Pretraining.ipynb notebook.
# If they are different, the loaded models/data processing will be incompatible.

# Input sizes
MAX_SMILES_LEN: int = 256  # length every tokenized SMILES is padded/truncated to
MAX_NODES: int = 419  # Or whatever max_nodes_found was in your 1_PubChem_SMILES_Extraction.ipynb output
NUM_ATOM_FEATURES: int = 5  # As defined in your atom_to_feature_vector

# Model widths
PROJECTION_DIM: int = 128
HIDDEN_DIM_GIN: int = 256
EMBED_DIM_TRANSFORMER: int = 256


# --- 2. Load Pre-trained Encoders ---
# Assuming your 2_GRASP_Pretraining.ipynb saved these to /kaggle/working/
# When you commit notebook 2, these will become available as a Kaggle dataset.
# You will need to "Add Data" to this notebook (3_GRASP_Finetune_BBBP.ipynb)
# and select the output of your 2_GRASP_Pretraining.ipynb.
# The path will typically be something like /kaggle/input/name-of-your-notebook2-commit/
# e.g., /kaggle/input/grasp-pretraining-model-v1/gin_encoder_pretrained
# Make sure these paths are correct after you add the data!

PRETRAINED_GIN_ENCODER_PATH = '/kaggle/input/your-2-pretraining-notebook-output/gin_encoder_pretrained' # <--- UPDATE THIS PATH
PRETRAINED_TRANSFORMER_ENCODER_PATH = '/kaggle/input/your-2-pretraining-notebook-output/transformer_encoder_pretrained' # <--- UPDATE THIS PATH

# experimental_io_device='/job:localhost' makes the restore access the filesystem from
# the local host. This is useful when the SavedModel lives on the notebook's local disk
# while running under a distributed (e.g. TPU) strategy.
_SAVED_MODEL_LOAD_OPTIONS = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')

with strategy.scope(): # Load within strategy scope for TPU compatibility
    # tf.saved_model.load rebuilds each encoder from the traced graphs and variables
    # stored in the SavedModel. It has no custom_objects parameter and does not use the
    # Python classes GINLayer / GINEncoder / TransformerEncoder. Those classes are only
    # defined further down this cell, so naming them here would raise NameError.
    # NOTE(review): the returned objects are restored trackables, not Keras layers (no
    # .fit/.predict). If Keras-layer behaviour is needed (e.g. `.trainable` freezing in
    # DownstreamModel), use keras.models.load_model(path, custom_objects={...}) *after*
    # the class definitions instead.
    pre_trained_gin_encoder = tf.saved_model.load(PRETRAINED_GIN_ENCODER_PATH, options=_SAVED_MODEL_LOAD_OPTIONS)
    pre_trained_transformer_encoder = tf.saved_model.load(PRETRAINED_TRANSFORMER_ENCODER_PATH, options=_SAVED_MODEL_LOAD_OPTIONS)

print("Pre-trained GIN Encoder and Transformer Encoder loaded successfully.")

# --- Define Helper Functions (Copied from pre-training notebook) ---
# These functions are needed to preprocess the MoleculeNet data in the same way
# your pre-trained encoders expect.

# --- SMILES Tokenization ---
# A simple character-level tokenizer for demonstration.
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    """Build a character-level SMILES vocabulary.

    The vocabulary is the four special tokens ``<pad>``, ``<unk>``, ``<cls>``
    and ``<eos>`` (ids 0-3), followed by every distinct character seen in
    ``smiles_list`` in sorted order.

    Args:
        smiles_list: iterable of SMILES strings.
        max_vocab_size: if truthy, keep only the first ``max_vocab_size``
            entries (special tokens first).

    Returns:
        ``(vocab, char_to_idx, idx_to_char)``: the token list and the
        token->id and id->token lookup dicts.
    """
    observed_chars = set()
    for smiles in smiles_list:
        observed_chars.update(smiles)

    vocab = ['<pad>', '<unk>', '<cls>', '<eos>'] + sorted(observed_chars)
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]

    char_to_idx = {}
    idx_to_char = {}
    for position, token in enumerate(vocab):
        char_to_idx[token] = position
        idx_to_char[position] = token
    return vocab, char_to_idx, idx_to_char

# --- SMILES vocabulary used to tokenize inputs for the pre-trained Transformer ---
# Token ids must match the pre-training vocabulary exactly (same tokens AND same
# indices), otherwise the pre-trained token embeddings are looked up with the wrong ids.
# Set PRETRAINED_VOCAB_PATH to a JSON file containing that vocabulary as a list of
# tokens in index order (special tokens included).
# NOTE(review): the pre-training notebook must save such a file (e.g. json.dump(vocab, f));
# adapt the loading code below to whatever format it actually writes.
PRETRAINED_VOCAB_PATH = None  # <--- e.g. '/kaggle/input/your-2-pretraining-notebook-output/smiles_vocab.json'

dummy_smiles_for_vocab = ["C", "CC", "CCC"] # Small dummy used only as a fallback

if PRETRAINED_VOCAB_PATH is not None:
    import json

    with open(PRETRAINED_VOCAB_PATH, encoding='utf-8') as _vocab_file:
        vocab = list(json.load(_vocab_file))
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    idx_to_char = {i: char for i, char in enumerate(vocab)}
    # tokenize_smiles and the padding values rely on these two special tokens.
    _missing_tokens = [t for t in ('<pad>', '<unk>') if t not in char_to_idx]
    if _missing_tokens:
        raise ValueError(f"Vocabulary at {PRETRAINED_VOCAB_PATH} lacks special tokens: {_missing_tokens}")
else:
    # With this dummy vocab every character except 'C' maps to '<unk>', so SMILES
    # embeddings from the pre-trained Transformer will not be meaningful.
    print("WARNING: PRETRAINED_VOCAB_PATH is not set; using a dummy SMILES vocabulary. "
          "Token ids will NOT match the pre-trained Transformer.")
    vocab, char_to_idx, idx_to_char = build_smiles_vocab(dummy_smiles_for_vocab)
VOCAB_SIZE = len(vocab)


def tokenize_smiles(smiles, char_to_idx, max_len):
    """Map a SMILES string to a fixed-length int32 array of character ids.

    Unknown characters map to ``char_to_idx['<unk>']``. Sequences longer than
    ``max_len`` are truncated; shorter ones are right-padded with
    ``char_to_idx['<pad>']``.

    Args:
        smiles: SMILES string (tokenized character by character).
        char_to_idx: token->id lookup containing at least ``<unk>`` and ``<pad>``.
        max_len: length of the returned array.

    Returns:
        ``np.ndarray`` of dtype int32 and length ``max_len``.
    """
    ids = [char_to_idx.get(ch, char_to_idx['<unk>']) for ch in list(smiles)]
    if len(ids) >= max_len:
        return np.array(ids[:max_len], dtype=np.int32)
    filler = [char_to_idx['<pad>']] * (max_len - len(ids))
    return np.array(ids + filler, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """Return a boolean padding mask for ``token_ids``.

    The mask is True where the token equals ``pad_token_id`` (padding) and
    False for real tokens.
    """
    is_padding = token_ids == pad_token_id
    return tf.cast(is_padding, tf.bool)


# --- SMILES to TensorFlow Graph Conversion Helpers ---
# Width of the per-atom vector produced by atom_to_feature_vector below (5 values).
# Re-assigned here with the same value as in the global configuration section; keep
# both in sync with each other and with the pre-training notebook.
NUM_ATOM_FEATURES = 5 # Must match your pre-training notebook

def atom_to_feature_vector(atom):
    """Encode an RDKit atom as a 5-element float32 vector.

    Order: atomic number, degree, hybridization (as int), aromatic flag
    (0/1), formal charge.
    """
    return np.array(
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
        ],
        dtype=np.float32,
    )

def bond_to_feature_vector(bond):
    """Encode an RDKit bond as float32 ``[bond type as int, conjugated flag (0/1)]``."""
    bond_type = int(bond.GetBondType())
    conjugated = int(bond.GetIsConjugated())
    return np.array([bond_type, conjugated], dtype=np.float32)

def smiles_to_tf_graph(smiles_string):
    """Convert a SMILES string into node features and a directed edge list.

    Args:
        smiles_string: SMILES string parsed with ``Chem.MolFromSmiles``.

    Returns:
        ``(node_features, edge_indices, num_nodes, num_edges)``:

        * ``node_features``: float32 array with one ``atom_to_feature_vector``
          row per atom.
        * ``edge_indices``: int32 array of shape ``(num_edges, 2)``. Each bond
          appears twice, as ``[i, j]`` and ``[j, i]``.
        * ``num_nodes`` / ``num_edges``: Python ints.

        A molecule with a single atom and no bonds gets one self-loop
        ``[[0, 0]]``. A multi-atom molecule without bonds (e.g. ``"[Na+].[Cl-]"``)
        gets an empty ``(0, 2)`` edge array. If RDKit cannot parse the string,
        or the molecule has no atoms, ``(None, None, None, None)`` is returned.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return None, None, None, None

    atom_rows = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not atom_rows:
        return None, None, None, None
    node_features = np.array(atom_rows, dtype=np.float32)
    num_nodes = len(node_features)

    edge_list = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_list.extend(([i, j], [j, i]))  # store both directions

    if not edge_list:
        if num_nodes == 1:
            return node_features, np.array([[0, 0]], dtype=np.int32), num_nodes, 1
        # Return numpy (like every other branch) instead of a tf.Tensor.
        return node_features, np.zeros((0, 2), dtype=np.int32), num_nodes, 0

    edge_indices = np.array(edge_list, dtype=np.int32)
    return node_features, edge_indices, num_nodes, len(edge_indices)


# --- 3. Data Preparation for BBBP Downstream Task ---

# Define the featurization function for the dataset map operation
# This is a slightly modified version for DownstreamModel input (adds label)
def featurize_smiles_and_graph_with_label(smiles_string, label):
    """Featurize one (SMILES, label) pair into DownstreamModel inputs.

    Calls ``.numpy()`` on its input, so it must run eagerly (e.g. inside
    ``tf.py_function``) rather than as traced graph code.

    Args:
        smiles_string: scalar ``tf.string`` tensor holding a UTF-8 SMILES string.
        label: label tensor; returned unchanged.

    Returns:
        ``(inputs, label)``. ``inputs`` is the 6-tuple
        ``(node_features, edge_indices, num_nodes, num_edges, token_ids, mask)``:

        * float32 node features padded to ``MAX_NODES`` rows;
        * the int32 ``(num_edges, 2)`` edge list;
        * scalar int32 node and edge counts;
        * int32 token ids of length ``MAX_SMILES_LEN``;
        * the boolean SMILES padding mask (True = padding).

        Molecules that cannot be converted to a graph, or that have more than
        ``MAX_NODES`` atoms, come back with ``num_nodes == 0`` so the dataset
        pipeline can filter them out.
    """
    smiles = smiles_string.numpy().decode('utf-8')  # decode once, reuse for both views
    token_ids = tokenize_smiles(smiles, char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles)

    # Padding to MAX_NODES rows needs a non-negative pad amount. Oversized molecules
    # therefore take the same "failed" path instead of crashing the input pipeline.
    if node_features is None or num_nodes > MAX_NODES:
        # Dummy values (shaped like real padded inputs) that will be filtered out later
        dummy_node_features = tf.zeros((MAX_NODES, NUM_ATOM_FEATURES), dtype=tf.float32)
        dummy_edge_indices = tf.zeros((0, 2), dtype=tf.int32)
        dummy_num_nodes = tf.constant(0, dtype=tf.int32)
        dummy_num_edges = tf.constant(0, dtype=tf.int32)
        return (dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, token_ids, mask), label

    padded_node_features = np.pad(node_features, [(0, MAX_NODES - num_nodes), (0, 0)])

    return ((tf.constant(padded_node_features, dtype=tf.float32),
             tf.convert_to_tensor(edge_indices, dtype=tf.int32),
             tf.constant(num_nodes, dtype=tf.int32),
             tf.constant(num_edges, dtype=tf.int32),
             tf.constant(token_ids, dtype=tf.int32),
             tf.cast(mask, tf.bool)),
            label) # Return inputs as a tuple and label separately


# Need to define padded_shapes and padding_values for tf.data.Dataset.padded_batch
# These should mirror the definitions in the pre-training notebook for the input features
# One entry per input component, in the order featurize_smiles_and_graph_with_label
# returns them: (padded shape, padding value, padding dtype). A None dimension is
# padded to the longest element in each batch.
_INPUT_PADDING_SPEC = (
    ([MAX_NODES, NUM_ATOM_FEATURES], 0.0, tf.float32),   # node features
    ([None, 2], 0, tf.int32),                            # edge list
    ([], 0, tf.int32),                                   # num_nodes
    ([], 0, tf.int32),                                   # num_edges
    ([MAX_SMILES_LEN], char_to_idx['<pad>'], tf.int32),  # SMILES token ids
    ([MAX_SMILES_LEN], True, tf.bool),                   # SMILES padding mask
)
padded_shapes_inputs = tuple(tf.TensorShape(shape) for shape, _, _ in _INPUT_PADDING_SPEC)
padding_values_inputs = tuple(
    tf.constant(value, dtype=dtype) for _, value, dtype in _INPUT_PADDING_SPEC
)

# Function to convert DeepChem dataset to TensorFlow dataset
def _entry_to_smiles(entry):
    """Return one DeepChem ``X`` entry as a SMILES string.

    Strings pass through unchanged, RDKit ``Mol`` objects are converted with
    ``Chem.MolToSmiles``, and ``None`` becomes ``''``. An empty SMILES has no
    atoms, so it cannot be turned into a graph and is filtered out later.
    """
    if isinstance(entry, str):
        return entry
    if entry is None:
        return ''
    return Chem.MolToSmiles(entry)


def dc_dataset_to_tf_dataset_for_downstream(dc_dataset, batch_size, shuffle_buffer_size=1000,
                                            shuffle=True, drop_remainder=True):
    """Convert a DeepChem dataset into a batched ``tf.data`` pipeline.

    Each element is ``(inputs, label)``. ``inputs`` is the 6-tuple produced by
    ``featurize_smiles_and_graph_with_label``, padded per ``padded_shapes_inputs``.
    Molecules whose graph conversion fails (``num_nodes == 0``) are dropped.

    Args:
        dc_dataset: DeepChem dataset. ``X`` holds SMILES strings or RDKit ``Mol``
            objects (the latter are converted back to SMILES); ``y`` holds labels.
        batch_size: batch size passed to ``padded_batch``.
        shuffle_buffer_size: buffer size for ``Dataset.shuffle``.
        shuffle: whether to shuffle examples (default True).
        drop_remainder: whether to drop the final partial batch (default True).

    Returns:
        A prefetched ``tf.data.Dataset``.
    """
    smiles_list = [_entry_to_smiles(entry) for entry in dc_dataset.X]
    labels = np.asarray(dc_dataset.y)  # (num_samples, num_tasks) when 2-D
    label_shape = labels.shape[1:] if labels.ndim > 1 else ()
    label_dtype = tf.as_dtype(labels.dtype)

    def generator():
        # Yield raw SMILES and label; featurization happens in the map below.
        yield from zip(smiles_list, labels)

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_signature=(tf.TensorSpec(shape=(), dtype=tf.string),  # Raw SMILES
                          tf.TensorSpec(shape=label_shape, dtype=label_dtype)),  # Label
    )

    # tf.py_function takes a flat list of output dtypes, so the featurizer's nested
    # ((6 inputs), label) output is flattened for the call and regrouped afterwards.
    flat_tout = [tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.bool, label_dtype]
    # py_function outputs carry no static shape; restore the known ranks/sizes.
    flat_shapes = [
        tf.TensorShape([None, NUM_ATOM_FEATURES]),  # node features
        tf.TensorShape([None, 2]),                  # edge list
        tf.TensorShape([]),                         # num_nodes
        tf.TensorShape([]),                         # num_edges
        tf.TensorShape([MAX_SMILES_LEN]),           # token ids
        tf.TensorShape([MAX_SMILES_LEN]),           # SMILES padding mask
        tf.TensorShape(label_shape),                # label
    ]

    def _featurize_flat(smiles, label):
        return tf.nest.flatten(featurize_smiles_and_graph_with_label(smiles, label))

    def _featurize(smiles, label):
        flat = tf.py_function(_featurize_flat, inp=[smiles, label], Tout=flat_tout)
        for tensor, shape in zip(flat, flat_shapes):
            tensor.set_shape(shape)
        return tuple(flat[:6]), flat[6]

    dataset = dataset.map(_featurize, num_parallel_calls=tf.data.AUTOTUNE)

    # Filter out failed graph conversions (where num_nodes was 0)
    dataset = dataset.filter(lambda inputs_tuple, label: inputs_tuple[2] > 0) # inputs_tuple[2] is num_nodes

    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Define full padded shapes and padding values including the label
    padded_shapes_full = (padded_shapes_inputs, tf.TensorShape(label_shape))
    padding_values_full = (padding_values_inputs, tf.constant(0, dtype=label_dtype))

    dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes_full,
                                   padding_values=padding_values_full,
                                   drop_remainder=drop_remainder)
    return dataset.prefetch(tf.data.AUTOTUNE)


# Load BBBP dataset
# 'Raw' featurizer is used to get raw molecules for your custom featurizer. Depending
# on the DeepChem version, X may hold RDKit Mol objects rather than SMILES strings;
# dc_dataset_to_tf_dataset_for_downstream accepts either.
bbbp_tasks, bbbp_datasets, bbbp_transformers = dc.molnet.load_bbbp(featurizer='Raw', splitter='scaffold')
train_bbbp, valid_bbbp, test_bbbp = bbbp_datasets

print(f"BBBP Train Samples: {len(train_bbbp)}")
print(f"BBBP Validation Samples: {len(valid_bbbp)}")
print(f"BBBP Test Samples: {len(test_bbbp)}")

# Convert DeepChem datasets to TensorFlow datasets
GLOBAL_BATCH_SIZE = 64 * strategy.num_replicas_in_sync # Use the same GLOBAL_BATCH_SIZE as pre-training

# Evaluation splits are not shuffled and keep their last partial batch, so every
# validation/test molecule is scored. With drop_remainder=True, a split smaller than
# GLOBAL_BATCH_SIZE would produce no batches at all.
# NOTE(review): on TPU the remainder is still dropped, assuming static batch shapes are
# required there — confirm, and check that the eval splits are not empty.
eval_drop_remainder = tpu is not None
train_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(train_bbbp, GLOBAL_BATCH_SIZE)
valid_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(
    valid_bbbp, GLOBAL_BATCH_SIZE, shuffle=False, drop_remainder=eval_drop_remainder)
test_bbbp_tf = dc_dataset_to_tf_dataset_for_downstream(
    test_bbbp, GLOBAL_BATCH_SIZE, shuffle=False, drop_remainder=eval_drop_remainder)


# --- 4. Define Downstream Model ---
# You need to redefine GINLayer, GINEncoder, TransformerEncoder classes here
# (or import them if saved in a separate utility script)
# for `tf.saved_model.load` and Keras to correctly build the graph.

# GIN Layer (Custom Keras Layer) - RE-DEFINE
# Within this file the encoders actually used come from tf.saved_model.load; this class
# only documents/rebuilds the architecture. Keep attribute names and creation order
# identical to the pre-training notebook so weights can line up.
class GINLayer(layers.Layer):
    """One Graph Isomorphism Network (GIN) message-passing layer.

    Computes ``MLP((1 + epsilon) * x + A @ x)``. ``A`` is a sparse adjacency matrix
    with value 1 per edge, ``epsilon`` is a trainable scalar initialized to 0, and
    ``activation`` (if given) is applied to the result.

    Args:
        output_dim: width of the MLP output.
        activation: optional Keras activation (name or callable).
    """
    def __init__(self, output_dim, activation=None, **kwargs):
        super(GINLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(name='epsilon', shape=(),
                                       initializer=keras.initializers.Constant(0.0),
                                       trainable=True)
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        """Apply one GIN update.

        Args:
            inputs: tuple ``(node_features, edge_indices_batch, num_nodes_batch)``.
                ``node_features`` must be 2-D ``(total_nodes, dim)``; its first dimension
                sizes the square adjacency matrix. ``edge_indices_batch`` is
                ``(num_edges, 2)`` with ``[row, col]`` pairs indexing those rows; it is
                presumably already offset per molecule by the caller (TODO confirm).
                ``num_nodes_batch`` is not used.

        Returns:
            ``(total_nodes, output_dim)`` tensor.
        """
        node_features, edge_indices_batch, num_nodes_batch = inputs # num_nodes_batch is unused here, but passed
        edge_values = tf.ones(tf.shape(edge_indices_batch)[0], dtype=tf.float32)
        total_nodes_in_batch = tf.shape(node_features)[0]
        adj_shape = tf.cast([total_nodes_in_batch, total_nodes_in_batch], dtype=tf.int64)
        adj_sparse = tf.sparse.SparseTensor(indices=tf.cast(edge_indices_batch, tf.int64),
                                            values=edge_values,
                                            dense_shape=adj_shape)
        # Row i of neighbor_sum = sum of x[j] over every edge [i, j].
        neighbor_sum = tf.sparse.sparse_dense_matmul(adj_sparse, node_features)
        combined_features = (1 + self.epsilon) * node_features + neighbor_sum
        output = self.mlp(combined_features)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def compute_output_shape(self, input_shape):
        # input_shape[0] is the node-feature shape: (total_nodes, dim) -> (total_nodes, output_dim)
        return input_shape[0][0], self.output_dim

# GIN Encoder - RE-DEFINE
# Keep attribute names and creation order identical to the pre-training notebook.
class GINEncoder(keras.Model):
    """Stack of GIN layers followed by masked sum-pooling into one vector per molecule.

    ``initial_mlp`` projects node features to ``hidden_dim``. Then ``num_layers``
    GINLayers run, each followed by BatchNormalization; every GINLayer except the
    last uses ReLU. ``num_node_features`` is accepted but not used.
    """
    def __init__(self, num_layers, hidden_dim, num_node_features, **kwargs):
        super(GINEncoder, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim 
        self.initial_mlp = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(hidden_dim) 
        ])
        self.gin_layers = []
        self.bns = []
        for i in range(num_layers): 
            self.gin_layers.append(GINLayer(hidden_dim, activation='relu' if i < num_layers - 1 else None))
            self.bns.append(layers.BatchNormalization())
    
    def call(self, inputs):
        """Encode a batch of padded graphs.

        Args:
            inputs: tuple ``(node_features, edge_indices, num_nodes)``.
                ``node_features`` is flattened to ``(batch * MAX_NODES, F)`` (the
                reshape below requires this); ``edge_indices`` are global row indices
                into it; ``num_nodes`` holds the real atom count per molecule.

        Returns:
            ``(batch, hidden_dim)`` sum of the first ``num_nodes`` node embeddings of
            each molecule.
        """
        node_features, edge_indices, num_nodes = inputs
        x = self.initial_mlp(node_features)
        # NOTE(review): padded node rows also go through the MLP/GIN layers and
        # BatchNormalization, so they affect batch statistics; they are only excluded
        # at the pooling step below.
        for i in range(len(self.gin_layers)): 
            x = self.gin_layers[i]((x, edge_indices, num_nodes))
            x = self.bns[i](x)
        batch_size = tf.shape(node_features)[0] // MAX_NODES
        x_reshaped = tf.reshape(x, (batch_size, MAX_NODES, self.hidden_dim))
        # 1.0 for real atoms, 0.0 for padded rows.
        sequence_mask = tf.sequence_mask(num_nodes, maxlen=MAX_NODES, dtype=tf.float32) 
        sequence_mask = tf.expand_dims(sequence_mask, axis=-1) 
        masked_x = x_reshaped * sequence_mask
        graph_embedding = tf.reduce_sum(masked_x, axis=1)
        return graph_embedding

# Transformer Encoder - RE-DEFINE
# Keep attribute names and creation order identical to the pre-training notebook.
class TransformerEncoder(keras.Model):
    """Character-level SMILES Transformer that returns one pooled vector per sequence.

    Token embeddings are added to a learned positional embedding of shape
    ``(1, max_seq_len, embed_dim)``. Then ``num_layers`` post-norm blocks run:
    multi-head self-attention + residual + LayerNorm, followed by a ReLU feed-forward
    block (``4 * embed_dim`` hidden) + residual + LayerNorm. The output is the mean over
    non-padding positions, passed through a final LayerNorm.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )
        self.encoder_layers = []
        for _ in range(num_layers):
            # Per block: [attention, norm after attention, FF in, FF out, norm after FF]
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False, mask=None):
        """Encode token ids into a ``(batch, embed_dim)`` embedding.

        Args:
            inputs: tuple ``(token_ids, padding_mask_bool)``; the mask is True at
                padding positions.
            training: forwarded to the attention layers (controls attention dropout).
            mask: unused.
        """
        token_ids, padding_mask_bool = inputs 
        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]
        # NOTE(review): Keras MultiHeadAttention documents `attention_mask` as a boolean
        # mask where 1 = attend and 0 = do not attend. This passes an *additive* float
        # mask instead (0.0 for real tokens, -1e9 for padding), so under the documented
        # semantics real tokens look masked. Verify against pre-training before changing
        # it: the saved weights were trained with this exact code.
        attention_mask_additive = tf.cast(padding_mask_bool, dtype=tf.float32) * -1e9
        attention_mask_additive = tf.expand_dims(attention_mask_additive, axis=1) 
        
        for i, (attention, norm1, ff_dense1, ff_dense2, norm2) in enumerate(self.encoder_layers):
            attn_output = attention(x, x, attention_mask=attention_mask_additive, training=training)
            x = norm1(x + attn_output)
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output)
        
        # Masked mean pooling over non-padding positions (1e-9 avoids division by zero).
        expanded_padding_mask = tf.cast(tf.expand_dims(padding_mask_bool, axis=-1), dtype=x.dtype)
        non_padded_mask = 1.0 - expanded_padding_mask
        x_masked = x * non_padded_mask
        sum_embeddings = tf.reduce_sum(x_masked, axis=1)
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9)
        
        return self.final_norm(smiles_embedding)


# Downstream Model
class DownstreamModel(keras.Model):
    """Property-prediction head on top of the pre-trained graph and SMILES encoders.

    The GIN graph embedding and the Transformer SMILES embedding are concatenated and
    fed to a small MLP head (Dense-ReLU, Dropout 0.3, Dense). The final activation is
    sigmoid for ``task_type='classification'`` and linear otherwise.

    Args:
        pre_trained_gin_encoder: callable mapping
            ``(node_features_flat, global_edge_indices, num_nodes)`` to graph embeddings.
        pre_trained_transformer_encoder: callable mapping ``(token_ids, smiles_mask)``
            to SMILES embeddings.
        output_dim: number of outputs (tasks).
        task_type: ``'classification'`` for sigmoid outputs, anything else for linear.
    """
    def __init__(self, pre_trained_gin_encoder, pre_trained_transformer_encoder, output_dim, task_type='classification', **kwargs):
        super(DownstreamModel, self).__init__(**kwargs)
        self.gin_encoder = pre_trained_gin_encoder
        self.transformer_encoder = pre_trained_transformer_encoder
        
        # Freeze the pre-trained encoders initially
        # You might unfreeze them after a few epochs for full fine-tuning
        # NOTE(review): in this notebook the encoders come from tf.saved_model.load, which
        # returns restored trackables rather than Keras layers. Setting `.trainable` on
        # them may not freeze or unfreeze anything; check bbbp_model.trainable_variables.
        self.gin_encoder.trainable = False
        self.transformer_encoder.trainable = False

        # Head for classification/regression
        # input_to_head_dim only sets the hidden Dense width: (PROJECTION_DIM + EMBED_DIM_TRANSFORMER) // 2.
        # The first Dense infers its real input width from the concatenated embeddings
        # at build time.
        # NOTE(review): the concatenated width is the GIN output width plus the
        # Transformer output width, which need not equal this sum — confirm the
        # intended hidden size.
        input_to_head_dim = PROJECTION_DIM + EMBED_DIM_TRANSFORMER

        self.classifier_head = keras.Sequential([
            layers.Dense(input_to_head_dim // 2, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(output_dim, activation='sigmoid' if task_type == 'classification' else 'linear')
        ])
        self.task_type = task_type

    def call(self, inputs, training=False):
        """Run both encoders and the head.

        Args:
            inputs: 6-tuple ``(node_features_padded, edge_indices_padded, num_nodes,
                num_edges, token_ids, smiles_mask)``. ``node_features_padded`` is
                ``(batch, MAX_NODES, F)`` and ``edge_indices_padded`` is
                ``(batch, max_edges, 2)`` with per-molecule (local) atom indices; only
                the first ``num_edges[b]`` edges of each molecule are valid.
            training: forwarded to the encoders and the head (BatchNorm/Dropout).

        Returns:
            ``(batch, output_dim)`` predictions.
        """
        node_features_padded, edge_indices_padded, num_nodes, num_edges, token_ids, smiles_mask = inputs

        # Flatten node_features_padded (must match GRASPModel.call preprocessing)
        node_features_flat = tf.reshape(node_features_padded, (-1, tf.shape(node_features_padded)[2]))
        
        batch_size = tf.shape(node_features_padded)[0] 
        
        # Edge preprocessing (must match GRASPModel.call preprocessing)
        # Keep only real edges, then shift molecule b's local atom indices by b * MAX_NODES
        # so they index rows of node_features_flat. tf.where and tf.boolean_mask both walk
        # edge_mask in row-major order, so batch ids line up with the kept edges.
        edge_mask = tf.sequence_mask(num_edges, maxlen=tf.shape(edge_indices_padded)[1], dtype=tf.bool)
        valid_edge_indices = tf.boolean_mask(edge_indices_padded, edge_mask)
        batch_ids_for_edges = tf.cast(tf.where(edge_mask)[:, 0], dtype=tf.int32)
        node_offsets_for_edges = tf.range(batch_size) * MAX_NODES
        offsets = tf.gather(node_offsets_for_edges, batch_ids_for_edges)
        offsets = tf.expand_dims(offsets, axis=-1)
        global_edge_indices_filtered = valid_edge_indices + offsets
        
        # Get embeddings from pre-trained encoders
        # Note: training=training is important to pass through for BatchNorm/Dropout
        graph_embeddings = self.gin_encoder((node_features_flat, global_edge_indices_filtered, num_nodes), training=training)
        smiles_embeddings = self.transformer_encoder((token_ids, smiles_mask), training=training)

        # Concatenate embeddings
        combined_embeddings = tf.concat([graph_embeddings, smiles_embeddings], axis=-1)

        # Pass through classification/regression head
        output = self.classifier_head(combined_embeddings, training=training)
        return output

# --- 5. Fine-tuning and Evaluation Loop (BBBP) ---

# BBBP task specific output dimension (binary classification)
bbbp_output_dim = len(bbbp_tasks) # Should be 1 for BBBP


def _make_dummy_downstream_input(batch_size):
    """Return a zero-filled 6-tuple matching DownstreamModel's input structure.

    Each molecule claims 50 nodes and 10 (all ``[0, 0]``) edges. The tuple is only
    used to run one forward pass so the model's variables get created.
    """
    return (tf.zeros((batch_size, MAX_NODES, NUM_ATOM_FEATURES)),    # node_features_padded
            tf.zeros((batch_size, 10, 2), dtype=tf.int32),           # dummy edges
            tf.constant([50] * batch_size, dtype=tf.int32),          # dummy num_nodes
            tf.constant([10] * batch_size, dtype=tf.int32),          # dummy num_edges
            tf.zeros((batch_size, MAX_SMILES_LEN), dtype=tf.int32),  # token ids
            tf.zeros((batch_size, MAX_SMILES_LEN), dtype=tf.bool))   # SMILES mask


with strategy.scope():
    bbbp_model = DownstreamModel(pre_trained_gin_encoder, pre_trained_transformer_encoder,
                                 output_dim=bbbp_output_dim, task_type='classification')

    bbbp_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        metrics=[keras.metrics.AUC(name='auc')],
    )

    # One forward pass on dummy data builds the model before fit() and surfaces
    # shape problems early. A failure is reported but does not stop the cell.
    dummy_input_shape = (GLOBAL_BATCH_SIZE, MAX_NODES, NUM_ATOM_FEATURES) # for node_features_padded
    dummy_input = _make_dummy_downstream_input(GLOBAL_BATCH_SIZE)
    try:
        _ = bbbp_model(dummy_input)
    except Exception as e:
        print(f"Error building BBBP DownstreamModel with dummy input: {e}")
    else:
        print("BBBP DownstreamModel built successfully with dummy input.")


print("\nFine-tuning BBBP model...")
FINE_TUNE_EPOCHS = 10 # Adjust as needed
bbbp_model.fit(train_bbbp_tf, epochs=FINE_TUNE_EPOCHS, validation_data=valid_bbbp_tf)

print("\nEvaluating BBBP model on test set...")
bbbp_results = bbbp_model.evaluate(test_bbbp_tf)
test_loss, test_auc = bbbp_results[0], bbbp_results[1]  # order: loss, then metrics as compiled
print(f"BBBP Test Loss: {test_loss:.4f}, Test AUC: {test_auc:.4f}")


# --- Optional: Save Fine-tuned Model (if desired) ---
# bbbp_model.save('bbbp_finetuned_model')
# print("BBBP fine-tuned model saved.")

print("\nBBBP fine-tuning complete!")
