# In [None]:
# preprocess_moleculenet.py

import os
import tensorflow as tf
import numpy as np
from rdkit import Chem
import deepchem as dc
from tqdm import tqdm

# Silence RDKit logging. Both the 'rdApp.warning' and the 'rdApp.error'
# channels are disabled, so RDKit prints nothing even when a SMILES string
# fails to parse.
from rdkit import rdBase
rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

# --- Global constants ---
# These MUST agree with the values used by the pre-training and fine-tuning notebooks.
MAX_SMILES_LEN = 256   # token sequences are padded / truncated to this length
MAX_NODES = 419        # node-feature matrices are padded to this many rows (must equal pre-training MAX_NODES)
NUM_ATOM_FEATURES = 5  # length of the vector returned by atom_to_feature_vector

# Destination directory for the preprocessed MoleculeNet TFRecords.
OUTPUT_TFRECORD_DIR = 'moleculenet_tfrecords'
os.makedirs(OUTPUT_TFRECORD_DIR, exist_ok=True)  # no error if the directory already exists

print(f"Output TFRecords will be saved to: {OUTPUT_TFRECORD_DIR}")
print(f"Configured MAX_NODES: {MAX_NODES}, NUM_ATOM_FEATURES: {NUM_ATOM_FEATURES}, MAX_SMILES_LEN: {MAX_SMILES_LEN}")


# --- Helper Functions (Copied from your notebooks for consistency) ---

# --- SMILES Tokenization ---
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    """Build a character-level SMILES vocabulary.

    The special tokens '<pad>', '<unk>', '<cls>', '<eos>' receive ids 0-3.
    They are followed by every distinct character seen in *smiles_list*,
    in sorted order. When *max_vocab_size* is truthy, the combined list is
    cut to that many entries, counting the special tokens.

    Returns:
        A tuple ``(vocab, char_to_idx, idx_to_char)``.
    """
    special_tokens = ['<pad>', '<unk>', '<cls>', '<eos>']
    distinct_chars = {ch for smiles in smiles_list for ch in smiles}
    vocab = special_tokens + sorted(distinct_chars)
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]

    char_to_idx = {}
    idx_to_char = {}
    for position, token in enumerate(vocab):
        char_to_idx[token] = position
        idx_to_char[position] = token

    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

def tokenize_smiles(smiles, char_to_idx, max_len):
    """Map a SMILES string to a fixed-length int32 array of character ids.

    Characters missing from *char_to_idx* become the '<unk>' id. The result
    is right-padded with the '<pad>' id, or truncated, to exactly *max_len*
    entries. No '<cls>' / '<eos>' tokens are inserted.
    """
    ids = [char_to_idx.get(ch, char_to_idx['<unk>']) for ch in smiles]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids.extend([char_to_idx['<pad>']] * shortfall)
    else:
        ids = ids[:max_len]
    return np.array(ids, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """Return a boolean tf.Tensor that is True exactly where *token_ids* equals *pad_token_id*."""
    padding_positions = token_ids == pad_token_id
    return tf.cast(padding_positions, dtype=tf.bool)


# --- SMILES to TensorFlow Graph Conversion ---
def atom_to_feature_vector(atom):
    """Encode an RDKit atom as a float32 vector of 5 features.

    Order: atomic number, degree, hybridization (as int), aromatic flag
    (0/1), formal charge.
    """
    return np.array(
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetFormalCharge(),
        ],
        dtype=np.float32,
    )

def smiles_to_tf_graph(smiles_string):
    """Convert a SMILES string into NumPy graph arrays.

    Returns a 4-tuple ``(node_features, edge_indices, num_nodes, num_edges)``:
      * node_features: float32 array of shape (num_nodes, NUM_ATOM_FEATURES),
        one row per atom, built by ``atom_to_feature_vector``.
      * edge_indices: int32 array of shape (num_edges, 2). Every RDKit bond
        is stored in both directions, (i, j) and (j, i).
      * num_nodes, num_edges: Python ints.

    Unparsable SMILES and molecules with no atoms yield empty arrays and zero
    counts. Molecules that have atoms but no bonds (e.g. a lone ion) keep
    their nodes and get an empty (0, 2) edge array.
    """
    empty_graph = (np.zeros((0, NUM_ATOM_FEATURES), dtype=np.float32),
                   np.zeros((0, 2), dtype=np.int32),
                   0,
                   0)

    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return empty_graph

    atom_rows = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not atom_rows:
        return empty_graph
    node_features = np.array(atom_rows, dtype=np.float32)

    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])
    # reshape(-1, 2) yields the (0, 2) shape when the molecule has no bonds,
    # so no separate "no edges" branch is needed.
    edge_indices = np.array(edges, dtype=np.int32).reshape(-1, 2)

    return node_features, edge_indices, len(node_features), len(edge_indices)


# Featurize a single sample. Returns a flat 7-tuple: NumPy arrays and Python ints for the graph and tokens, a tf.Tensor for the SMILES mask, and the label passed through unchanged.
def featurize_smiles_and_graph_with_label(smiles_string, label):
    """Featurize one (SMILES, label) sample into a flat 7-tuple.

    Returns:
        ``(padded_node_features, edge_indices, num_nodes, num_edges,
        token_ids, mask, label)``. The elements are:
          * padded_node_features: float32 array of shape
            (MAX_NODES, NUM_ATOM_FEATURES), zero-padded.
          * token_ids: int32 array from ``tokenize_smiles``.
          * mask: tf.Tensor from ``create_smiles_mask``.
          * label: the input label, unchanged.

    A molecule whose graph cannot be built, or whose atom count exceeds
    MAX_NODES, yields an all-zero placeholder with ``num_nodes == 0``.
    Callers use that to skip the sample. Oversized molecules used to crash
    ``np.pad`` with a negative pad width.

    Relies on the module-level ``char_to_idx`` vocabulary being defined.
    """
    token_ids = tokenize_smiles(smiles_string, char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles_string)

    # Molecules larger than MAX_NODES cannot fit the fixed-size node tensor,
    # so they are treated like unparsable molecules and reported as empty.
    if num_nodes == 0 or num_nodes > MAX_NODES:
        dummy_node_features = np.zeros((MAX_NODES, NUM_ATOM_FEATURES), dtype=np.float32)
        dummy_edge_indices = np.zeros((0, 2), dtype=np.int32)
        return (dummy_node_features, dummy_edge_indices, 0, 0, token_ids, mask, label)

    padded_node_features = np.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])

    return (padded_node_features, edge_indices, num_nodes, num_edges, token_ids, mask, label)


# --- TFRecord Serialization Functions ---
def _bytes_feature(value):
    """Wrap a value as a tf.train.Feature holding a single bytes entry.

    Conversion steps:
      * A tf.Tensor is first converted with ``.numpy()``.
      * NumPy arrays and NumPy numeric/bool scalars are serialized with
        ``.tobytes()``. Only the raw buffer is written, with no dtype or
        shape, so readers must know both.
      * Anything else (e.g. ``bytes``) is passed through unchanged.

    A 0-d tensor's ``.numpy()`` returns a NumPy scalar rather than an
    ``np.ndarray``. It is handled explicitly so it is converted to bytes
    instead of being handed to BytesList as a number.
    """
    if isinstance(value, tf.Tensor):
        value = value.numpy()
    if isinstance(value, (np.ndarray, np.number, np.bool_)):
        value = value.tobytes()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    """Wrap a single integer as a tf.train.Feature with a one-element Int64List."""
    int64_values = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_values)

def serialize_moleculenet_example(node_feat_padded, edge_idx, num_nodes, num_edges, token_ids, smiles_mask, label):
    """Pack one featurized molecule into a tf.train.Example.

    ``num_nodes`` and ``num_edges`` are stored as int64 features. Every
    other field goes through ``_bytes_feature``, i.e. as raw bytes without
    dtype/shape metadata. The label is stored as bytes so multi-task labels
    (e.g. Tox21) fit the same schema.
    """
    # (feature key, encoder, value) in the order the features are emitted.
    columns = (
        ('node_feat_padded', _bytes_feature, node_feat_padded),
        ('edge_idx', _bytes_feature, edge_idx),
        ('num_nodes', _int64_feature, num_nodes),
        ('num_edges', _int64_feature, num_edges),
        ('token_ids', _bytes_feature, token_ids),
        ('smiles_mask', _bytes_feature, smiles_mask),
        ('label', _bytes_feature, label),
    )
    feature_map = {key: encode(value) for key, encode, value in columns}
    return tf.train.Example(features=tf.train.Features(feature=feature_map))


# --- Main Preprocessing Loop for MoleculeNet Datasets ---
def _label_numpy_dtype(y):
    """Return the NumPy dtype that labels from array *y* are serialized as.

    Boolean and any integer labels are promoted to float32. Every other
    dtype is kept as-is, so float labels keep their existing on-disk byte
    layout.
    """
    if y.dtype == np.bool_ or np.issubdtype(y.dtype, np.integer):
        return np.dtype(np.float32)
    return y.dtype


def _iter_smiles_and_labels(dataset_split):
    """Yield ``(smiles, label_row)`` pairs from a DeepChem dataset split.

    With ``featurizer='Raw'``, DeepChem's RawFeaturizer stores RDKit Mol
    objects in ``X`` by default (``smiles=False``), and those cannot be
    tokenized. For any non-string X entry, the SMILES is taken from the
    sample id instead. This presumes the MoleculeNet loaders fill ``ids``
    from the SMILES column; verify for the installed DeepChem version.

    X, y and ids are read once each. Indexing ``dataset_split.X[i]`` inside a
    loop re-materializes the whole array on every access.
    """
    for x, label, sample_id in zip(dataset_split.X, dataset_split.y, dataset_split.ids):
        smiles = x if isinstance(x, str) else str(sample_id)
        yield smiles, label


def _write_split_tfrecord(dataset_split, tfrecord_path, desc, labels_dtype):
    """Featurize one dataset split and write it to *tfrecord_path*.

    Samples whose featurization reports ``num_nodes == 0`` are skipped.

    Returns:
        ``(num_written, num_skipped)``.
    """
    num_written = 0
    num_skipped = 0
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        samples = _iter_smiles_and_labels(dataset_split)
        for smiles_str, label in tqdm(samples, total=len(dataset_split), desc=desc):
            processed_data_flat = featurize_smiles_and_graph_with_label(smiles_str, label)
            if processed_data_flat[2] <= 0:  # num_nodes: no usable graph
                num_skipped += 1
                continue
            # Cast labels to one fixed dtype so every record has the same byte layout.
            # This produces the same bytes as the previous tf.constant(..., dtype=...) cast.
            label_array = np.asarray(processed_data_flat[6], dtype=labels_dtype)
            example = serialize_moleculenet_example(*processed_data_flat[:6], label=label_array)
            writer.write(example.SerializeToString())
            num_written += 1
    return num_written, num_skipped


def _process_molnet_dataset(display_name, file_prefix, load_fn):
    """Load a MoleculeNet dataset with a scaffold split and write train/valid/test TFRecords.

    Output files are ``<OUTPUT_TFRECORD_DIR>/<file_prefix>_<split>.tfrecord``.
    The label dtype is decided once from the training split's ``y``.

    NOTE(review): the transformers returned by the loader are discarded.
    If the loader's default transformers modify ``y`` (for ESOL/Delaney this
    may include label normalization), the stored labels are transformed and
    no statistics are saved to undo it. Confirm against the fine-tuning
    notebook.
    """
    print(f"\nProcessing {display_name} dataset...")
    _tasks, datasets, _transformers = load_fn(featurizer='Raw', splitter='scaffold')
    labels_dtype = _label_numpy_dtype(datasets[0].y)

    for dataset_split, split_name in zip(datasets, ('train', 'valid', 'test')):
        tfrecord_path = os.path.join(OUTPUT_TFRECORD_DIR, f'{file_prefix}_{split_name}.tfrecord')
        num_written, num_skipped = _write_split_tfrecord(
            dataset_split, tfrecord_path, f"Featurizing {display_name} {split_name}", labels_dtype)
        print(f"Saved {num_written} samples to {tfrecord_path}")
        if num_skipped:
            print(f"Skipped {num_skipped} samples without a usable molecular graph")


if __name__ == "__main__":
    print("Starting MoleculeNet preprocessing script...")

    # NOTE(review): this vocabulary comes from a hard-coded character list,
    # not from the pre-training vocabulary. Token ids only line up with the
    # pre-trained model if pre-training built its vocab from exactly these
    # characters. Characters outside the list (e.g. aromatic 'o'/'s', '/',
    # '\', '.', 'a', 'e') are mapped to '<unk>'. Prefer loading the
    # pre-training vocab here.
    dummy_smiles_for_vocab_build = ["C", "N", "O", "F", "P", "S", "Cl", "Br", "I", "c", "n", "=", "#", "(", ")", "[", "]", "@", "+", "-", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "H", "B", "b", "K", "k", "L", "l", "M", "m", "R", "r", "X", "x", "Y", "y", "Z", "z"]
    # Module-level assignment: featurize_smiles_and_graph_with_label reads this global.
    _, char_to_idx, _ = build_smiles_vocab(dummy_smiles_for_vocab_build)
    VOCAB_SIZE = len(char_to_idx)
    print(f"Vocabulary built with {VOCAB_SIZE} tokens for MoleculeNet processing.")

    # BBBP: binary classification, single task.
    _process_molnet_dataset('BBBP', 'bbbp', dc.molnet.load_bbbp)

    # Tox21: multi-task classification.
    # NOTE(review): DeepChem appears to mark missing Tox21 labels via zero
    # entries in the dataset weights `w`. `w` is not serialized, so missing
    # labels are indistinguishable from negatives downstream. TODO confirm.
    _process_molnet_dataset('Tox21', 'tox21', dc.molnet.load_tox21)

    # ESOL: regression, single task. Some DeepChem versions expose ESOL only as
    # `load_delaney`, so fall back to it when `load_esol` is absent.
    load_esol = getattr(dc.molnet, 'load_esol', None) or dc.molnet.load_delaney
    _process_molnet_dataset('ESOL', 'esol', load_esol)

    print("\nMoleculeNet preprocessing script finished.")
