In [1]:
!pip install rdkit
!pip install tensorflow-macos
!pip install tensorflow-metal
!pip install tqdm



In [2]:
import sys
print(sys.executable)

/Users/alihussain/miniforge3/envs/tf-metal/bin/python


# Importing Libraries

In [3]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from rdkit import Chem
from tqdm import tqdm

# Initialize Apple Silicon (M1/M2/M3) GPU (MPS)

In [4]:
print(f"TensorFlow version: {tf.__version__}")

if tf.config.experimental.list_physical_devices("GPU"):
    print("GPU is available.")
    device = "mps"
else:
    print("GPU not available, using CPU.")
    device = "cpu"

# Setting whatever the default device
tf.config.set_visible_devices(tf.config.list_physical_devices('GPU'), 'GPU')
print(f"Using device: {device}")

TensorFlow version: 2.16.2
GPU is available.
Using device: mps


# Data Loading and Preprocessing

In [5]:
SMILES_FILE_PATH = 'pubchem_smiles_for_pretraining.txt'

if not os.path.exists(SMILES_FILE_PATH):
    raise FileNotFoundError(f"Dataset file not found at '{SMILES_FILE_PATH}'.")

max_nodes_found = 0
num_lines_to_check = 100000 

with open(SMILES_FILE_PATH, 'r') as f:
    for i, line in enumerate(f):
        if i >= num_lines_to_check:
            break
        smiles = line.strip()
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            num_nodes = mol.GetNumAtoms()
            if num_nodes > max_nodes_found:
                max_nodes_found = num_nodes
print(f"Maximum nodes found in the first {num_lines_to_check} molecules: {max_nodes_found}")

def load_smiles_data(file_path, num_samples=None):
    smiles_list = []
    with open(file_path, 'r') as f:
        for i, line in enumerate(f):
            if num_samples and i >= num_samples:
                break
            smiles_list.append(line.strip())
    print(f"Loaded {len(smiles_list)} SMILES strings.")
    return smiles_list


# For testing we will use 10k, 50k, 100k for quick tests basically.
all_smiles = load_smiles_data(SMILES_FILE_PATH, num_samples=100000) 

# We are Defining the featurization function for the dataset map operation
def featurize_smiles_and_graph(smiles_string):
    token_ids = tokenize_smiles(smiles_string.numpy().decode('utf-8'), char_to_idx, MAX_SMILES_LEN)
    mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    node_features, edge_indices, num_nodes, num_edges = smiles_to_tf_graph(smiles_string.numpy().decode('utf-8'))

    # Handling the cases where graph conversion fails suhc as invalid smiles
    if node_features is None:
        # we are creating dummy tensors with shapes that can be padded later
        dummy_node_features = tf.zeros((0, NUM_ATOM_FEATURES), dtype=tf.float32)
        dummy_edge_indices = tf.zeros((0, 2), dtype=tf.int32)
        dummy_num_nodes = tf.constant(0, dtype=tf.int32)
        dummy_num_edges = tf.constant(0, dtype=tf.int32) 
        return dummy_node_features, dummy_edge_indices, dummy_num_nodes, dummy_num_edges, token_ids, mask
    
    # We will ensure node_features has consistent shape by padding if necessary for batching
    padded_node_features = tf.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    
    return (tf.constant(padded_node_features, dtype=tf.float32),
            tf.constant(edge_indices, dtype=tf.int32),
            tf.constant(num_nodes, dtype=tf.int32),
            tf.constant(num_edges, dtype=tf.int32),
            tf.constant(token_ids, dtype=tf.int32),
            tf.constant(mask, dtype=tf.bool))





Maximum nodes found in the first 100000 molecules: 419
Loaded 100000 SMILES strings.




# FurtherData Processing & Creating tf.data.Dataset

In [6]:
# SMILES Tokenization 
def build_smiles_vocab(smiles_list, max_vocab_size=None):
    all_chars = set()
    for smiles in smiles_list:
        for char in smiles:
            all_chars.add(char)
    vocab = sorted(list(all_chars))
    # Adding special tokens
    vocab = ['<pad>', '<unk>', '<cls>', '<eos>'] + vocab
    if max_vocab_size:
        vocab = vocab[:max_vocab_size]
    char_to_idx = {char: i for i, char in enumerate(vocab)}
    idx_to_char = {i: char for i, char in enumerate(vocab)}
    print(f"Built vocabulary of size: {len(vocab)}")
    return vocab, char_to_idx, idx_to_char

vocab, char_to_idx, idx_to_char = build_smiles_vocab(all_smiles)
VOCAB_SIZE = len(vocab)
MAX_SMILES_LEN = 256 # Max sequence length for Transformer.
MAX_NODES = max_nodes_found # Maximum number of nodes in any graph in a batch.

def tokenize_smiles(smiles, char_to_idx, max_len):
    """Here we convert a SMILES string to a sequence of token IDs."""
    tokens = list(smiles)
    indexed_tokens = [char_to_idx.get(char, char_to_idx['<unk>']) for char in tokens]
    
    if len(indexed_tokens) < max_len:
        padded_tokens = indexed_tokens + [char_to_idx['<pad>']] * (max_len - len(indexed_tokens))
    else:
        padded_tokens = indexed_tokens[:max_len]
    return np.array(padded_tokens, dtype=np.int32)

def create_smiles_mask(token_ids, pad_token_id):
    """This will create a boolean mask for padded tokens."""
    return tf.cast(token_ids == pad_token_id, tf.bool)

#  SMILES to TensorFlow Graph Conversion 
def atom_to_feature_vector(atom):
    """This will convert an RDKit atom to a feature vector."""
    features = []
    features.append(atom.GetAtomicNum())
    features.append(atom.GetDegree())
    features.append(int(atom.GetHybridization())) # Converting enum to int
    features.append(int(atom.GetIsAromatic()))
    features.append(atom.GetFormalCharge())
    return np.array(features, dtype=np.float32)

NUM_ATOM_FEATURES = len(atom_to_feature_vector(Chem.Atom(6))) 

def smiles_to_tf_graph(smiles_string):
    """
    Converts a SMILES string to TensorFlow graph components:
    node_features, edge_indices, num_nodes, and num_edges.
    """
    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:
        return None, None, None, None 

    node_features = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    if not node_features:
        return None, None, None, None 
    node_features = np.array(node_features, dtype=np.float32)
    num_nodes = len(node_features)

    edge_indices = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])

    if len(edge_indices) == 0:
        if num_nodes > 0: # Handling single atom molecules
            edge_indices_final = np.empty((0, 2), dtype=np.int32)
            num_edges_final = 0
        else: 
            return None, None, None, None
    else:
        edge_indices_final = np.array(edge_indices, dtype=np.int32)
        num_edges_final = len(edge_indices_final)

    return node_features, edge_indices_final, num_nodes, num_edges_final

dataset = tf.data.Dataset.from_tensor_slices(all_smiles)

dataset = dataset.map(lambda x: tf.py_function(
    featurize_smiles_and_graph,
    inp=[x],
    Tout=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.bool)
), num_parallel_calls=tf.data.AUTOTUNE)

dataset = dataset.filter(lambda node_feat, edge_idx, num_nodes, num_edges, token_ids, mask: num_nodes > 0)

BATCH_SIZE = 32 
GLOBAL_BATCH_SIZE = BATCH_SIZE
padded_shapes = (
    tf.TensorShape([MAX_NODES, NUM_ATOM_FEATURES]), # node_features
    tf.TensorShape([None, 2]),                     # edge_indices
    tf.TensorShape([]),                           # num_nodes
    tf.TensorShape([]),                           # num_edges
    tf.TensorShape([MAX_SMILES_LEN]),              # token_ids
    tf.TensorShape([MAX_SMILES_LEN])               # mask
)
padding_values = (
    tf.constant(0.0, dtype=tf.float32),
    tf.constant(0, dtype=tf.int32),
    tf.constant(0, dtype=tf.int32),
    tf.constant(0, dtype=tf.int32),
    tf.constant(char_to_idx['<pad>'], dtype=tf.int32),
    tf.constant(True, dtype=tf.bool)
)

dataset = dataset.cache()
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=padded_shapes, padding_values=padding_values, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

Built vocabulary of size: 71


2025-06-22 23:07:27.507969: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-06-22 23:07:27.507992: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-06-22 23:07:27.507997: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-06-22 23:07:27.508015: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-06-22 23:07:27.508025: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


#  Model Architecture (TensorFlow/Keras) 

In [7]:
class GINLayer(layers.Layer):
    def __init__(self, output_dim, activation=None, **kwargs):
        super(GINLayer, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.mlp = keras.Sequential([
            layers.Dense(output_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
        self.epsilon = self.add_weight(name='epsilon', shape=(),
                                       initializer=keras.initializers.Constant(0.0),
                                       trainable=True)
        self.activation = keras.activations.get(activation)

    def call(self, inputs):
        node_features, edge_indices_batch, num_nodes_batch = inputs
        
        edge_values = tf.ones(tf.shape(edge_indices_batch)[0], dtype=tf.float32)
        
        total_nodes_in_batch = tf.shape(node_features)[0]
        adj_shape = tf.cast([total_nodes_in_batch, total_nodes_in_batch], dtype=tf.int64)

        adj_sparse = tf.sparse.SparseTensor(indices=tf.cast(edge_indices_batch, tf.int64),
                                            values=edge_values,
                                            dense_shape=adj_shape)
        
        neighbor_sum = tf.sparse.sparse_dense_matmul(adj_sparse, node_features)

        combined_features = (1 + self.epsilon) * node_features + neighbor_sum
        output = self.mlp(combined_features)

        if self.activation is not None:
            output = self.activation(output)
        return output

class GINEncoder(keras.Model):
    def __init__(self, num_layers, hidden_dim, num_node_features, **kwargs): 
        super(GINEncoder, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.initial_mlp = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(hidden_dim)
        ])
        self.gin_layers = [GINLayer(hidden_dim, activation='relu') for _ in range(num_layers - 1)]
        self.gin_layers.append(GINLayer(hidden_dim, activation=None))
        self.bns = [layers.BatchNormalization() for _ in range(num_layers)]
    
    def call(self, inputs):
        node_features, edge_indices, num_nodes = inputs
        x = self.initial_mlp(node_features)
        for i in range(len(self.gin_layers)):
            x = self.gin_layers[i]((x, edge_indices, num_nodes))
            x = self.bns[i](x)
        
        batch_size = tf.shape(node_features)[0] // MAX_NODES
        x_reshaped = tf.reshape(x, (batch_size, MAX_NODES, self.hidden_dim))
        sequence_mask = tf.sequence_mask(num_nodes, maxlen=MAX_NODES, dtype=tf.float32)
        sequence_mask = tf.expand_dims(sequence_mask, axis=-1)
        masked_x = x_reshaped * sequence_mask
        graph_embedding = tf.reduce_sum(masked_x, axis=1)
        return graph_embedding

class TransformerEncoder(keras.Model):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_seq_len, dropout_rate=0.1, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.token_embedding = layers.Embedding(vocab_size, embed_dim)
        self.positional_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, max_seq_len, embed_dim),
            initializer="random_normal",
            trainable=True
        )
        self.encoder_layers = []
        for _ in range(num_layers):
            self.encoder_layers.append([
                layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout_rate),
                layers.LayerNormalization(epsilon=1e-6),
                layers.Dense(embed_dim * 4, activation="relu"),
                layers.Dense(embed_dim),
                layers.LayerNormalization(epsilon=1e-6),
            ])
        self.final_norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False):
        token_ids, padding_mask_bool = inputs
        x = self.token_embedding(token_ids)
        x = x + self.positional_embedding[:, :tf.shape(x)[1], :]
        attention_mask_additive = tf.cast(padding_mask_bool, dtype=tf.float32) * -1e9
        attention_mask_additive = tf.expand_dims(attention_mask_additive, axis=1)
        
        for attention, norm1, ff_dense1, ff_dense2, norm2 in self.encoder_layers:
            attn_output = attention(x, x, attention_mask=attention_mask_additive, training=training)
            x = norm1(x + attn_output)
            ff_output = ff_dense2(ff_dense1(x))
            x = norm2(x + ff_output)
        
        expanded_padding_mask = tf.cast(tf.expand_dims(padding_mask_bool, axis=-1), dtype=x.dtype)
        non_padded_mask = 1.0 - expanded_padding_mask
        x_masked = x * non_padded_mask
        sum_embeddings = tf.reduce_sum(x_masked, axis=1)
        non_padded_len = tf.reduce_sum(non_padded_mask, axis=1)
        smiles_embedding = sum_embeddings / (non_padded_len + 1e-9)
        return self.final_norm(smiles_embedding)

class ProjectionHead(keras.Model):
    def __init__(self, input_dim, output_dim, hidden_dim=256, **kwargs):
        super(ProjectionHead, self).__init__(**kwargs)
        self.net = keras.Sequential([
            layers.Dense(hidden_dim, activation='relu'),
            layers.Dense(output_dim)
        ])
    def call(self, x):
        return self.net(x)

class GRASPModel(keras.Model):
    def __init__(self, gin_config, transformer_config, projection_dim, **kwargs):
        super(GRASPModel, self).__init__(**kwargs)
        self.gin_encoder = GINEncoder(**gin_config)
        self.transformer_encoder = TransformerEncoder(**transformer_config)
        self.graph_projection_head = ProjectionHead(gin_config['hidden_dim'], projection_dim)
        self.smiles_projection_head = ProjectionHead(transformer_config['embed_dim'], projection_dim)
    
    def call(self, inputs, training=False):
        node_features_padded, edge_indices_padded, num_nodes, num_edges, token_ids, smiles_mask = inputs
        node_features_flat = tf.reshape(node_features_padded, (-1, tf.shape(node_features_padded)[2]))
        batch_size = tf.shape(node_features_padded)[0]
        
        edge_mask = tf.sequence_mask(num_edges, maxlen=tf.shape(edge_indices_padded)[1], dtype=tf.bool)
        valid_edge_indices = tf.cast(tf.boolean_mask(edge_indices_padded, edge_mask), dtype=tf.int32)
        
        node_offsets_for_edges = tf.range(batch_size) * MAX_NODES
        node_offsets_for_edges = tf.expand_dims(node_offsets_for_edges, axis=1)
        node_offsets_for_edges_expanded = tf.boolean_mask(tf.tile(node_offsets_for_edges, [1, tf.shape(edge_indices_padded)[1]]), edge_mask)
        node_offsets_for_edges_expanded = tf.expand_dims(node_offsets_for_edges_expanded, axis=-1)
        
        global_edge_indices_filtered = valid_edge_indices + tf.cast(node_offsets_for_edges_expanded, dtype=tf.int32)
        
        graph_embeddings_raw = self.gin_encoder((node_features_flat, global_edge_indices_filtered, num_nodes), training=training)
        graph_embeddings_projected = self.graph_projection_head(graph_embeddings_raw, training=training)
        
        smiles_embeddings_raw = self.transformer_encoder((token_ids, smiles_mask), training=training)
        smiles_embeddings_projected = self.smiles_projection_head(smiles_embeddings_raw, training=training)
        
        graph_embeddings_projected = tf.linalg.normalize(graph_embeddings_projected, axis=1)[0]
        smiles_embeddings_projected = tf.linalg.normalize(smiles_embeddings_projected, axis=1)[0]
        
        return graph_embeddings_projected, smiles_embeddings_projected

#  Contrastive Loss (InfoNCE) 

In [8]:
class InfoNCELoss(keras.losses.Loss):
    def __init__(self, temperature=0.07, name='info_nce_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.temperature = temperature

    @tf.function
    def call(self, graph_embeddings, smiles_embeddings):
        logits = tf.matmul(graph_embeddings, smiles_embeddings, transpose_b=True) / self.temperature
        batch_size = tf.shape(logits)[0]
        labels = tf.eye(batch_size)
        loss_g_s = tf.keras.losses.categorical_crossentropy(labels, logits, from_logits=True)
        loss_s_g = tf.keras.losses.categorical_crossentropy(labels, tf.transpose(logits), from_logits=True)
        total_loss = (loss_g_s + loss_s_g) / 2
        return tf.reduce_mean(total_loss)

#  Training Loop 

In [9]:
PROJECTION_DIM = 128
HIDDEN_DIM_GIN = 256
NUM_GIN_LAYERS = 3
EMBED_DIM_TRANSFORMER = 256
NUM_TRANSFORMER_HEADS = 8
NUM_TRANSFORMER_LAYERS = 3

gin_config = {
    'num_layers': NUM_GIN_LAYERS,
    'hidden_dim': HIDDEN_DIM_GIN,
    'num_node_features': NUM_ATOM_FEATURES
}

transformer_config = {
    'vocab_size': VOCAB_SIZE,
    'embed_dim': EMBED_DIM_TRANSFORMER,
    'num_heads': NUM_TRANSFORMER_HEADS,
    'num_layers': NUM_TRANSFORMER_LAYERS,
    'max_seq_len': MAX_SMILES_LEN
}

# Instantiating model, loss, and optimizer
model = GRASPModel(gin_config, transformer_config, PROJECTION_DIM)
info_nce_loss = InfoNCELoss(temperature=0.07)
optimizer = keras.optimizers.Adam(learning_rate=1e-4)

# The train_step function performs a single gradient update.
@tf.function(reduce_retracing=True)
def train_step(inputs):
    with tf.GradientTape() as tape:
        graph_embeddings, smiles_embeddings = model(inputs, training=True)
        loss = info_nce_loss(graph_embeddings, smiles_embeddings)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

EPOCHS = 5
# here we  the number of steps per epoch for tqdm
steps_per_epoch_tqdm = len(all_smiles) // GLOBAL_BATCH_SIZE
steps_per_epoch_tqdm = max(1, steps_per_epoch_tqdm)

print(f"\nStarting pre-training for {EPOCHS} epochs...")
print(f"Dataset has {steps_per_epoch_tqdm} batches per epoch.")

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    total_loss = 0.0
    
    epoch_iterator = tqdm(dataset, desc=f"Epoch {epoch + 1} Training", total=steps_per_epoch_tqdm, leave=True)
    
    num_batches = 0 

    for i, batch_inputs in enumerate(epoch_iterator):
        loss = train_step(batch_inputs)
        total_loss += loss
        num_batches += 1
        
        epoch_iterator.set_postfix(loss=f"{loss.numpy():.4f}")
    
    if num_batches > 0: # Ensure num_batches is not zero to avoid division by zero
        avg_loss = total_loss / num_batches
        print(f"\nEpoch {epoch + 1} finished. Average Loss: {avg_loss:.4f}")
    else:
        print(f"\nEpoch {epoch + 1} finished. No batches processed (check dataset size/filtering).")

print("\nPre-training complete!")


Starting pre-training for 5 epochs...
Dataset has 3125 batches per epoch.

Epoch 1/5


2025-06-22 23:07:37.765524: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:450] ShuffleDatasetV3:4: Filling up shuffle buffer (this may take a while): 5453 of 10000
2025-06-22 23:07:44.957851: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:480] Shuffle buffer filled.
2025-06-22 23:07:47.327036: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
Epoch 1 Training: 100%|██████████| 3125/3125 [24:13<00:00,  2.35it/s, loss=1.1741]2025-06-22 23:31:40.806109: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 1 Training: 100%|██████████| 3125/3125 [24:13<00:00,  2.15it/s, loss=1.1741]



Epoch 1 finished. Average Loss: 1.2832

Epoch 2/5


Epoch 2 Training: 100%|██████████| 3125/3125 [21:54<00:00,  2.38it/s, loss=1.1583]2025-06-22 23:53:35.054577: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 2 Training: 100%|██████████| 3125/3125 [21:54<00:00,  2.38it/s, loss=1.1583]



Epoch 2 finished. Average Loss: 1.1770

Epoch 3/5


Epoch 3 Training: 100%|██████████| 3125/3125 [22:25<00:00,  2.32it/s, loss=1.1018]2025-06-23 00:16:00.389725: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 3 Training: 100%|██████████| 3125/3125 [22:25<00:00,  2.32it/s, loss=1.1018]



Epoch 3 finished. Average Loss: 1.2637

Epoch 4/5


Epoch 4 Training: 100%|██████████| 3125/3125 [22:48<00:00,  2.29it/s, loss=1.2610]2025-06-23 00:38:49.222983: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 4 Training: 100%|██████████| 3125/3125 [22:48<00:00,  2.28it/s, loss=1.2610]



Epoch 4 finished. Average Loss: 1.4348

Epoch 5/5


Epoch 5 Training: 100%|██████████| 3125/3125 [23:01<00:00,  2.22it/s, loss=1.3705]2025-06-23 01:01:51.081134: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Epoch 5 Training: 100%|██████████| 3125/3125 [23:01<00:00,  2.26it/s, loss=1.3705]


Epoch 5 finished. Average Loss: 1.5458

Pre-training complete!





#  Saving the encoders and the model 

In [10]:
tf.saved_model.save(model.gin_encoder, 'gin_encoder_pretrained')
tf.saved_model.save(model.transformer_encoder, 'transformer_encoder_pretrained')
print("Encoders saved.")

model.export('grasp_pretrained_model_tf_savedmodel')
print("Full Model saved to 'grasp_pretrained_model_tf_savedmodel' directory.")

INFO:tensorflow:Assets written to: gin_encoder_pretrained/assets


INFO:tensorflow:Assets written to: gin_encoder_pretrained/assets


INFO:tensorflow:Assets written to: transformer_encoder_pretrained/assets


INFO:tensorflow:Assets written to: transformer_encoder_pretrained/assets


Encoders saved.
INFO:tensorflow:Assets written to: grasp_pretrained_model_tf_savedmodel/assets


INFO:tensorflow:Assets written to: grasp_pretrained_model_tf_savedmodel/assets


Saved artifact at 'grasp_pretrained_model_tf_savedmodel'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): Tuple[TensorSpec(shape=(None, 419, 5), dtype=tf.float32, name=None), TensorSpec(shape=(None, 114, 2), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256), dtype=tf.float32, name=None), TensorSpec(shape=(None, 256), dtype=tf.float32, name=None)]
Output Type:
  Tuple[TensorSpec(shape=(None, 128), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128), dtype=tf.float32, name=None)]
Captures:
  6362182080: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6362182608: TensorSpec(shape=(), dtype=tf.resource, name=None)
  5980019296: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6366714992: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6000846176: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6362180320: T