In [None]:
import tensorflow as tf
import deepchem as dc
import numpy as np
from rdkit import Chem
from rdkit import rdBase
import os
from tqdm.auto import tqdm

# Suppress RDKit warning AND error logs to keep the output clean.
# Note: disabling 'rdApp.error' also hides SMILES parse errors, so molecules
# that fail to parse are dropped without any message (the featurizer returns
# None for them and the writer loop skips them).
rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

## --- CONFIGURATION ---
# List of MoleculeNet datasets you want to process.
# Only names handled explicitly in process_and_save_datasets() are processed.
DATASET_NAMES = ['Tox21', 'BBBP', 'ESOL']

# Parameters that MUST match your pre-training setup
# Token sequences are truncated / right-padded to exactly this many characters.
MAX_SMILES_LEN = 256
# Molecules with more atoms than this (as counted by RDKit's GetNumAtoms())
# are skipped; atom-feature matrices are zero-padded to this many rows.
# NOTE(review): presumably the largest molecule size used in pre-training — confirm.
MAX_NODES = 419
# Match the number of features your pre-training GIN model expects.
# Must equal the length of the vector returned by atom_to_feature_vector.
NUM_ATOM_FEATURES = 5

# Output directory for the new TFRecord files. Using a new name is recommended.
# Created when this cell runs; exist_ok=True makes re-running the cell safe.
OUTPUT_DIR = 'moleculenet_tfrecords_v2'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# --- TOKENIZER (Must be identical to your pre-training notebook) ---
# We use the same hardcoded vocab to ensure consistency.
#
# WARNING: keep the VOCAB line byte-identical to pre-training. The character
# string contains duplicates ('C', 'B', 'l' and 'r' each appear twice, via
# "Cl", "Br", "Ll" and "Rr"), and sorted(list(...)) keeps them. VOCAB therefore
# has 52 entries but only 48 distinct keys, and the dict comprehension below
# maps each repeated character to its LAST index, leaving the earlier indices
# unused. Deduplicating or reordering would shift token IDs and break
# consistency with the pre-trained model.
# Tokenization is per character (see tokenize_smiles), so "Cl"/"Br" become two
# tokens. Characters absent from the string (e.g. aromatic 'o'/'s'/'p', '-',
# '/', backslash, '.', '%') map to '<unk>'. '<cls>' and '<eos>' are never
# emitted by tokenize_smiles in this notebook.
VOCAB = ['<pad>', '<unk>', '<cls>', '<eos>'] + sorted(list("CN=OP#SFClBI()[]@+1234567890HBrKkLlMmRrXxYyZzcbn"))
CHAR_TO_IDX = {char: i for i, char in enumerate(VOCAB)}
PAD_TOKEN_ID = CHAR_TO_IDX['<pad>']
# len(VOCAB) counts the duplicate entries, so it exceeds the number of distinct IDs actually produced.
print(f"Using a consistent vocabulary of size: {len(VOCAB)}")


def tokenize_smiles(smiles, max_len):
    """Map a SMILES string to a fixed-length int32 array of vocabulary IDs.

    Each character is looked up on its own in CHAR_TO_IDX, and characters
    missing from the vocabulary become the '<unk>' ID. The resulting sequence
    is cut to ``max_len`` entries and then right-padded with PAD_TOKEN_ID up
    to ``max_len``.

    Args:
        smiles: SMILES string to encode.
        max_len: Length of the returned array.

    Returns:
        np.ndarray of dtype int32 and length ``max_len``.
    """
    ids = [CHAR_TO_IDX.get(ch, CHAR_TO_IDX['<unk>']) for ch in smiles][:max_len]
    ids.extend([PAD_TOKEN_ID] * (max_len - len(ids)))
    return np.array(ids, dtype=np.int32)


# --- GRAPH FEATURIZER (From your pre-training notebook) ---
def atom_to_feature_vector(atom):
    """Return the float32 descriptor vector of one atom.

    Feature order (must match pre-training):
    atomic number, degree, hybridization (integer value of the enum),
    aromatic flag (0/1), formal charge.
    """
    atomic_num = atom.GetAtomicNum()
    degree = atom.GetDegree()
    hybridization = int(atom.GetHybridization())
    aromatic = int(atom.GetIsAromatic())
    charge = atom.GetFormalCharge()
    descriptor = [atomic_num, degree, hybridization, aromatic, charge]
    return np.array(descriptor, dtype=np.float32)

def smiles_to_graph_and_tokens(smiles_string, max_nodes, max_len):
    """
    Converts a SMILES string to all necessary features for the GRASP model.

    Args:
        smiles_string: SMILES string of one molecule.
        max_nodes: Maximum number of atoms. Larger molecules are skipped, and
            atom features are zero-padded to this many rows.
        max_len: Length of the padded/truncated token-ID sequence.

    Returns:
        None if the SMILES cannot be parsed, yields a molecule with no atoms,
        or has more than ``max_nodes`` atoms. Otherwise a tuple
        ``(padded_atom_features, edge_index, num_nodes, token_ids)``:
          * padded_atom_features: float32, shape (max_nodes, NUM_ATOM_FEATURES)
          * edge_index: int32, shape (num_edges, 2), each bond listed in both
            directions; shape (0, 2) for molecules without bonds
          * num_nodes: int32 array of shape (1,) holding the real atom count
          * token_ids: the output of ``tokenize_smiles(smiles_string, max_len)``
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return None  # unparsable SMILES

    num_nodes = mol.GetNumAtoms()
    # A zero-atom Mol (e.g. from an empty SMILES string) is rejected like an
    # invalid one. An empty feature list has shape (0,), which cannot be
    # broadcast into the (0, NUM_ATOM_FEATURES) padding slice below.
    if num_nodes == 0 or num_nodes > max_nodes:
        return None

    # --- Graph Features ---
    # 1. Atom features, zero-padded to a fixed (max_nodes, NUM_ATOM_FEATURES) block.
    atom_features = np.array(
        [atom_to_feature_vector(atom) for atom in mol.GetAtoms()], dtype=np.float32
    )
    padded_atom_features = np.zeros((max_nodes, NUM_ATOM_FEATURES), dtype=np.float32)
    padded_atom_features[:num_nodes] = atom_features

    # 2. Edge index list: every bond contributes (i, j) and (j, i), so the graph is undirected.
    edge_indices = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_indices.extend(((i, j), (j, i)))
    # reshape keeps a (0, 2) array for bond-less molecules (e.g. a single atom).
    edge_index_array = np.array(edge_indices, dtype=np.int32).reshape(-1, 2)

    # --- SMILES Features ---
    # Note: the token sequence is truncated at max_len while the graph is not,
    # so very long SMILES lose sequence information but keep the full graph.
    token_ids = tokenize_smiles(smiles_string, max_len)

    return padded_atom_features, edge_index_array, np.array([num_nodes], dtype=np.int32), token_ids


# --- TFRECORD SERIALIZATION ---
# Helper functions to create TensorFlow features
def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` holding a one-element BytesList.

    A value whose type matches ``type(tf.constant(0))`` (the tensor type in
    eager mode, e.g. the output of ``tf.io.serialize_tensor``) is first
    converted with ``.numpy()``. Any other value is stored as given.
    """
    eager_tensor_type = type(tf.constant(0))
    if not isinstance(value, eager_tensor_type):
        payload = value
    else:
        payload = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

def create_tf_example(atom_features, edge_index, num_nodes, token_ids, label):
    """Pack one molecule's arrays into a ``tf.train.Example``.

    Each array is serialized with ``tf.io.serialize_tensor`` and stored as a
    single-element bytes feature under a fixed key. A reader must recover it
    with ``tf.io.parse_tensor`` using the dtype the array was written with.
    """
    named_tensors = (
        ('atom_features', atom_features),
        ('edge_index', edge_index),
        ('num_nodes', num_nodes),
        ('token_ids', token_ids),
        ('label', label),
    )
    feature_map = {
        key: _bytes_feature(tf.io.serialize_tensor(tensor))
        for key, tensor in named_tensors
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:


# --- MAIN PROCESSING LOOP ---
def _load_molnet_splits(name):
    """Load one MoleculeNet dataset via DeepChem; return (train, valid, test), or None if unsupported."""
    # Dataset name -> (DeepChem loader attribute, splitter). The loader is looked
    # up lazily so only the requested loader has to exist in dc.molnet.
    loaders = {
        'Tox21': ('load_tox21', 'scaffold'),
        'BBBP': ('load_bbbp', 'scaffold'),
        'ESOL': ('load_esol', 'random'),
    }
    if name not in loaders:
        return None
    loader_name, splitter = loaders[name]
    # NOTE(review): MolNet loaders apply their default `transformers` to the
    # returned datasets (in recent DeepChem versions: y-normalization for ESOL,
    # weight balancing for Tox21/BBBP). They are discarded here, so regression
    # labels may be stored in normalized units — confirm before reporting metrics.
    _tasks, datasets, _transformers = getattr(dc.molnet, loader_name)(
        featurizer='Raw', splitter=splitter
    )
    return datasets


def _write_split(dataset, output_filename, desc):
    """Featurize every molecule of one DeepChem split and write it to a TFRecord file; return (written, skipped)."""
    written = skipped = 0
    with tf.io.TFRecordWriter(output_filename) as writer:
        rows = zip(dataset.ids, dataset.y, dataset.w)
        for smiles, label, weight in tqdm(rows, total=len(dataset), desc=desc):
            featurized_data = smiles_to_graph_and_tokens(smiles, MAX_NODES, MAX_SMILES_LEN)
            if featurized_data is None:
                skipped += 1  # invalid, empty, or larger than MAX_NODES
                continue

            atom_f, edge_idx, num_n, token_ids = featurized_data
            # float32 labels, one entry per task (multi-task for Tox21).
            label_np = np.array(label, dtype=np.float32)
            tf_example = create_tf_example(atom_f, edge_idx, num_n, token_ids, label_np)

            # DeepChem marks missing labels (common in multi-task Tox21) with
            # y=0 and w=0, so without the weights a missing label is
            # indistinguishable from a negative. Store them under an extra key.
            # Parsers that do not request 'weight' simply ignore it.
            # NOTE(review): the default balancing transformer may also rescale
            # the non-zero weights; use `weight > 0` as the label mask.
            tf_example.features.feature['weight'].CopyFrom(
                _bytes_feature(tf.io.serialize_tensor(np.array(weight, dtype=np.float32)))
            )
            writer.write(tf_example.SerializeToString())
            written += 1
    return written, skipped


def process_and_save_datasets(seed=0):
    """Load, featurize and save every dataset in DATASET_NAMES as TFRecord files.

    For each dataset, writes ``{name.lower()}_{split}.tfrecord`` into OUTPUT_DIR
    for the train, valid and test splits (existing files are overwritten).
    Each record holds the keys written by ``create_tf_example`` plus a
    ``weight`` tensor (per-task DeepChem weights; 0 marks a missing label).
    Unknown dataset names are reported and skipped.

    Args:
        seed: Seed for NumPy's global RNG, set before each dataset is loaded so
            that random splits (ESOL) are reproducible. Pass None to leave the
            RNG untouched.
    """
    for name in DATASET_NAMES:
        print(f"\n--- Processing dataset: {name} ---")

        if seed is not None:
            # DeepChem's RandomSplitter shuffles with NumPy's global RNG.
            # NOTE(review): if DeepChem reloads a cached split from its data
            # directory, the seed has no effect on that split — confirm.
            np.random.seed(seed)

        # 1. Load data using DeepChem's MolNet
        datasets = _load_molnet_splits(name)
        if datasets is None:
            print(f"  ⚠️ Unknown dataset '{name}' — skipping (no loader configured).")
            continue
        train_dataset, valid_dataset, test_dataset = datasets

        # 2. Process each split (train, valid, test)
        for split_name, dataset in (('train', train_dataset), ('valid', valid_dataset), ('test', test_dataset)):
            output_filename = os.path.join(OUTPUT_DIR, f'{name.lower()}_{split_name}.tfrecord')
            written, skipped = _write_split(dataset, output_filename, desc=f"  Writing {split_name}")
            print(f"  ✅ Saved {written} molecules to {output_filename} (skipped {skipped})")

    print(f"\n--- All datasets processed successfully and saved in '{OUTPUT_DIR}'! ---")

In [None]:
# Run the entire preprocessing pipeline.
# Writes one TFRecord file per dataset split into OUTPUT_DIR; re-running
# this cell rewrites those files from scratch.
process_and_save_datasets()