In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import glob
import os

#  CONFIGURATION 
DATASET_NAME = 'bbbp'
MOLECULENET_PATH = '/kaggle/input/moleculenet-tfrecords-v2/moleculenet_tfrecords_v2/'
GRASP_CHECKPOINT_PATH = '/kaggle/input/pretraining-checkpoints/pretraining_checkpoints/'
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 0.001
MAX_NODES = 419
NUM_ATOM_FEATURES = 5

#  DATA PIPELINE 
def masked_binary_crossentropy(y_true, y_pred):
    mask = tf.math.is_finite(y_true); y_true_masked = tf.where(mask, y_true, tf.zeros_like(y_true)); y_pred_masked = tf.where(mask, y_pred, tf.zeros_like(y_pred)); bce = tf.keras.losses.binary_crossentropy(y_true_masked, y_pred_masked); masked_bce = tf.boolean_mask(bce, mask); return tf.reduce_mean(masked_bce)
TASK_SPECS = {'tox21': {'type': 'classification', 'output_units': 12, 'loss_fn': masked_binary_crossentropy, 'metrics': [tf.keras.metrics.AUC(name='auc', multi_label=True, num_labels=12)]}, 'bbbp': {'type': 'classification', 'output_units': 1, 'loss_fn': 'binary_crossentropy', 'metrics': [tf.keras.metrics.AUC(name='auc')]}, 'esol': {'type': 'regression', 'output_units': 1, 'loss_fn': 'mean_squared_error', 'metrics': [tf.keras.metrics.RootMeanSquaredError(name='rmse'), 'mae']}}

def parse_fn(example):
    feature_description = {'atom_features': tf.io.FixedLenFeature([], tf.string), 'edge_index': tf.io.FixedLenFeature([], tf.string), 'num_nodes': tf.io.FixedLenFeature([], tf.string), 'token_ids': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.string),}; example = tf.io.parse_single_example(example, feature_description); atom_features = tf.io.parse_tensor(example['atom_features'], out_type=tf.float32); edge_index = tf.io.parse_tensor(example['edge_index'], out_type=tf.int32); num_nodes = tf.io.parse_tensor(example['num_nodes'], out_type=tf.int32); token_ids = tf.io.parse_tensor(example['token_ids'], out_type=tf.int32); label = tf.io.parse_tensor(example['label'], out_type=tf.float32); label = tf.reshape(label, [TASK_SPECS[DATASET_NAME]['output_units']]); return (atom_features, edge_index, num_nodes, token_ids), label

@tf.function
def prepare_batch_for_model(features, label):
    atom_features, edge_index, num_nodes, token_ids = features; atom_features_flat = tf.reshape(atom_features, (BATCH_SIZE * MAX_NODES, NUM_ATOM_FEATURES)); num_nodes_squeezed = tf.squeeze(num_nodes, axis=-1); node_offsets = tf.cumsum(num_nodes_squeezed, exclusive=True); is_real_edge_mask = edge_index[:, :, 0] >= 0; edge_batch_ids = tf.where(is_real_edge_mask)[:, 0]; edge_batch_ids = tf.cast(edge_batch_ids, dtype=tf.int32); edge_offsets = tf.gather(node_offsets, edge_batch_ids); real_edges = tf.boolean_mask(edge_index, is_real_edge_mask); global_edge_index = real_edges + tf.expand_dims(edge_offsets, axis=-1); padding_mask = (token_ids != 0); model_inputs = {'atom_features_input': atom_features_flat, 'edge_index_input': global_edge_index, 'num_nodes_input': num_nodes_squeezed, 'token_ids_input': token_ids, 'padding_mask_input': padding_mask}; return model_inputs, label

def create_dataset(file_pattern, should_shuffle=False):
    files = glob.glob(file_pattern);
    if not files: return None
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE).map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    if should_shuffle: dataset = dataset.shuffle(buffer_size=1024)
    dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=((tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), tf.TensorShape([None, 2]), tf.TensorShape([1]), tf.TensorShape([256])), tf.TensorShape([TASK_SPECS[DATASET_NAME]['output_units']])), drop_remainder=True)
    dataset = dataset.map(prepare_batch_for_model, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)

#  Subclassed keras.Model 
class GraspLinearProbeModel(keras.Model):
    def __init__(self, checkpoint_path, task_spec, **kwargs):
        super().__init__(**kwargs)
        self.gin_model = tf.saved_model.load(os.path.join(checkpoint_path, 'gin_encoder_best'))
        self.transformer_model = tf.saved_model.load(os.path.join(checkpoint_path, 'transformer_encoder_best'))
        self.gin_function = self.gin_model.signatures['serving_default']
        self.transformer_function = self.transformer_model.signatures['serving_default']
        output_activation = 'sigmoid' if task_spec['type'] == 'classification' else 'linear'
        self.head = layers.Dense(task_spec['output_units'], activation=output_activation)
        self.concat = layers.Concatenate()
        
    def call(self, data, training=False):
        atom_feats, edge_index, num_nodes, token_ids, padding_mask = (
            data['atom_features_input'], data['edge_index_input'], data['num_nodes_input'],
            data['token_ids_input'], data['padding_mask_input']
        )
        
        # Cast inputs for the GIN layer
        edge_index_float = tf.cast(edge_index, tf.float32)
        num_nodes_float = tf.cast(num_nodes, tf.float32)
        
        graph_embedding = self.gin_function(
            inputs=atom_feats, 
            inputs_1=edge_index_float, 
            inputs_2=num_nodes_float
        )
        
        # we casts inputs for the Transformer layer as well
        token_ids_float = tf.cast(token_ids, tf.float32)
        padding_mask_float = tf.cast(padding_mask, tf.float32)
        
        smiles_embedding = self.transformer_function(
            inputs=token_ids_float, 
            inputs_1=padding_mask_float
        )
        
        concatenated_embeddings = self.concat([graph_embedding['output_0'], smiles_embedding['output_0']])
        return self.head(concatenated_embeddings)

    def train_step(self, data):
        inputs, y_true = data
        with tf.GradientTape() as tape:
            y_pred = self(inputs, training=True)
            loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
            
        trainable_vars = self.head.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        inputs, y_true = data
        y_pred = self(inputs, training=False)
        self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}

# main execution 
def run_linear_probing():
    print(f" Starting Linear Probing for Dataset: {DATASET_NAME} ")
    task_spec = TASK_SPECS[DATASET_NAME]
    
    model = GraspLinearProbeModel(GRASP_CHECKPOINT_PATH, task_spec)
    optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(optimizer=optimizer, loss=task_spec['loss_fn'], metrics=task_spec['metrics'])
    
    train_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME}_train.tfrecord'), should_shuffle=True)
    valid_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME}_valid.tfrecord'))
    test_ds = create_dataset(os.path.join(MOLECULENET_PATH, f'{DATASET_NAME}_test.tfrecord'))
    if not train_ds: raise ValueError(f"Training TFRecord not found for {MOLECULENET_PATH}")

    print("\n Starting Training of the Linear Head ")
    model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds, verbose=1)
    
    print("\n Training Finished ")
    print("\n Final Performance on Unseen Test Data ")
    results = model.evaluate(test_ds)
    for name, value in results.items():
        print(f"  Final Test {name}: {value:.4f}")

run_linear_probing()

2025-07-03 19:32:09.511218: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751571129.753532      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751571129.822119      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-03 19:32:25.488026: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


--- 🚀 Starting Linear Probing for Dataset: bbbp ---

--- Starting Training of the Linear Head ---
Epoch 1/50


```
for metric in self.metrics:
    metric.update_state(y, y_pred)
```

  return self._compiled_metrics_update_state(


     25/Unknown [1m65s[0m 2s/step - auc: 0.4762 - loss: 0.2308



[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 3s/step - auc: 0.4730 - loss: 0.2409 - val_loss: 0.9231
Epoch 2/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 3s/step - auc: 0.3441 - loss: 0.5925 - val_loss: 0.8579
Epoch 3/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 3s/step - auc: 0.4482 - loss: 0.9049 - val_loss: 0.9943
Epoch 4/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 3s/step - auc: 0.4298 - loss: 0.7962 - val_loss: 0.6476
Epoch 5/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 3s/step - auc: 0.3764 - loss: 0.7431 - val_loss: 0.9998
Epoch 6/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 3s/step - auc: 0.3908 - loss: 0.6425 - val_loss: 0.9849
Epoch 7/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 3s/step - auc: 0.4855 - loss: 0.9764 - val_loss: 0.9720
Epoch 8/50
[1m25/25[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 3s/step - a

AttributeError: 'list' object has no attribute 'items'