In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import glob
import os

# --- CONFIGURATION FOR ESOL ---
DATASET_NAME = 'esol'
MOLECULENET_PATH = '/kaggle/input/moleculenet-tfrecords-final/moleculenet_tfrecords_final/'
GRASP_CHECKPOINT_PATH = '/kaggle/input/pretraining-checkpoints/pretraining_checkpoints/'
BATCH_SIZE = 64
EPOCHS = 1
LEARNING_RATE = 0.001
MAX_NODES = 419
NUM_ATOM_FEATURES = 5

# --- DATA PIPELINE ---

# The vocabulary must be identical to your pre-training script.
DUMMY_SMILES_FOR_VOCAB = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "c", "n", "=", "#", "(", ")", "[", "]", "@", "+", "-", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "H", "B", "b", "K", "k", "L", "l", "M", "m", "R", "r", "X", "x", "Y", "y", "Z", "z"] 
VOCAB = ['<pad>', '<unk>', '<cls>', '<eos>'] + sorted(list(set("".join(DUMMY_SMILES_FOR_VOCAB))))
CHAR_TO_IDX = {char: i for i, char in enumerate(VOCAB)}

def parse_fn(example):
    feature_description = {
        'atom_features': tf.io.FixedLenFeature([], tf.string), 'edge_index': tf.io.FixedLenFeature([], tf.string),
        'num_nodes': tf.io.FixedLenFeature([], tf.string), 'token_ids': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, feature_description)
    atom_features = tf.io.parse_tensor(example['atom_features'], out_type=tf.float32)
    edge_index = tf.io.parse_tensor(example['edge_index'], out_type=tf.int32)
    num_nodes = tf.io.parse_tensor(example['num_nodes'], out_type=tf.int32)
    token_ids = tf.io.parse_tensor(example['token_ids'], out_type=tf.int32)
    label = tf.io.parse_tensor(example['label'], out_type=tf.float32)
    # For ESOL, the label shape is (1,)
    label = tf.reshape(label, [1])
    return (atom_features, edge_index, num_nodes, token_ids), label

@tf.function
def prepare_batch_for_model(features, label):
    atom_features, edge_index, num_nodes, token_ids = features
    atom_features_flat = tf.reshape(atom_features, (BATCH_SIZE * MAX_NODES, NUM_ATOM_FEATURES))
    num_nodes_squeezed = tf.squeeze(num_nodes, axis=-1)
    node_offsets = tf.cumsum(num_nodes_squeezed, exclusive=True)
    is_real_edge_mask = edge_index[:, :, 0] >= 0
    edge_batch_ids = tf.where(is_real_edge_mask)[:, 0]
    edge_batch_ids = tf.cast(edge_batch_ids, dtype=tf.int32)
    edge_offsets = tf.gather(node_offsets, edge_batch_ids)
    real_edges = tf.boolean_mask(edge_index, is_real_edge_mask)
    global_edge_index = real_edges + tf.expand_dims(edge_offsets, axis=-1)
    padding_mask = (token_ids != CHAR_TO_IDX['<pad>'])
    model_inputs = {
        'atom_features_input': atom_features_flat, 'edge_index_input': global_edge_index,
        'num_nodes_input': num_nodes_squeezed, 'token_ids_input': token_ids,
        'padding_mask_input': padding_mask
    }
    return model_inputs, label

def create_dataset(file_pattern, should_shuffle=False):
    files = glob.glob(file_pattern)
    if not files: return None
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE).map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    if should_shuffle: dataset = dataset.shuffle(buffer_size=1024)
    # The label shape for padded_batch is now hardcoded to 1 for ESOL
    dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=((tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), tf.TensorShape([None, 2]), tf.TensorShape([1]), tf.TensorShape([256])), tf.TensorShape([1])), drop_remainder=True)
    dataset = dataset.map(prepare_batch_for_model, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)


# --- Subclassed keras.Model ---
class GraspLinearProbeModel(keras.Model):
    def __init__(self, checkpoint_path, **kwargs):
        super().__init__(**kwargs)
        self.gin_model = tf.saved_model.load(os.path.join(checkpoint_path, 'gin_encoder_best'))
        self.transformer_model = tf.saved_model.load(os.path.join(checkpoint_path, 'transformer_encoder_best'))
        self.gin_function = self.gin_model.signatures['serving_default']
        self.transformer_function = self.transformer_model.signatures['serving_default']
        # The head now has 1 output unit and a linear activation for regression
        self.head = layers.Dense(1, activation='linear')
        self.concat = layers.Concatenate()
        
    def call(self, data, training=False):
        atom_feats, edge_index, num_nodes, token_ids, padding_mask = (
            data['atom_features_input'], data['edge_index_input'], data['num_nodes_input'],
            data['token_ids_input'], data['padding_mask_input']
        )
        edge_index_float = tf.cast(edge_index, tf.float32)
        num_nodes_float = tf.cast(num_nodes, tf.float32)
        token_ids_float = tf.cast(token_ids, tf.float32)
        padding_mask_float = tf.cast(padding_mask, tf.float32)
        
        graph_embedding = self.gin_function(inputs=atom_feats, inputs_1=edge_index_float, inputs_2=num_nodes_float)
        smiles_embedding = self.transformer_function(inputs=token_ids_float, inputs_1=padding_mask_float)
        
        concatenated_embeddings = self.concat([graph_embedding['output_0'], smiles_embedding['output_0']])
        return self.head(concatenated_embeddings)

    def train_step(self, data):
        inputs, y_true = data
        with tf.GradientTape() as tape:
            y_pred = self(inputs, training=True)
            loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
            
        trainable_vars = self.head.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        inputs, y_true = data
        y_pred = self(inputs, training=False)
        self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

# --- Main Execution ---
def run_linear_probing():
    print(f"--- 🚀 Starting Linear Probing for Dataset: {DATASET_NAME} ---")
    
    model = GraspLinearProbeModel(GRASP_CHECKPOINT_PATH)
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    
    # Compile the model with regression loss and metrics
    model.compile(optimizer=optimizer, 
                  loss='mean_squared_error', 
                  metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse'), 
                           tf.keras.metrics.MeanAbsoluteError(name='mae')])
    
    train_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME.lower()}_train.tfrecord'), should_shuffle=True)
    valid_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME.lower()}_valid.tfrecord'))
    test_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME.lower()}_test.tfrecord'))
    if not train_ds: raise ValueError(f"Training TFRecord not found for {MOLECULENET_PATH}")

    print("\n--- Starting Training of the Linear Head ---")
    model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds, verbose=1)
    
    print("\n--- ✅ Training Finished ---")
    print("\n--- 🧪 Final Performance on Unseen Test Data ---")
    
    # Run evaluate to update the metric states
    model.evaluate(test_ds, verbose=0)

    # **THE FIX**: This robust loop checks the type of the metric result before printing.
    for metric in model.metrics:
        result = metric.result()
        # Check if the result is a dictionary (which happens for some metrics)
        if isinstance(result, dict):
            for key, value in result.items():
                print(f"  Final Test {key}: {value.numpy():.4f}")
        else:
            print(f"  Final Test {metric.name}: {result.numpy():.4f}")

# Run the pipeline
run_linear_probing()


--- 🚀 Starting Linear Probing for Dataset: esol ---

--- Starting Training of the Linear Head ---
[1m14/14[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 2s/step - mae: 1356.7762 - rmse: 16310.2812 - loss: 1244.9127 - val_loss: 8.5890

--- ✅ Training Finished ---

--- 🧪 Final Performance on Unseen Test Data ---
  Final Test loss: -53.9369
  Final Test rmse: 182.0762
  Final Test mae: 53.8660
