In [10]:
print("\n Starting Post-Pretraining Qualitative Evaluation ")

# Prefer the encoders attached to an in-memory `model` (e.g. from an earlier
# training cell); otherwise fall back to the SavedModel directories on disk.
try:
    if 'model' not in locals() or not hasattr(model, 'gin_encoder') or not hasattr(model, 'transformer_encoder'):
        print("Loading encoders from disk...")
        loaded_gin_encoder = tf.saved_model.load('gin_encoder_pretrained')
        loaded_transformer_encoder = tf.saved_model.load('transformer_encoder_pretrained')
    else:
        print("Using encoders directly from trained model in memory.")
        loaded_gin_encoder = model.gin_encoder
        loaded_transformer_encoder = model.transformer_encoder
except Exception as e:
    # Broad catch is deliberate: any failure here leaves no encoders to evaluate.
    print(f"Error loading encoders: {e}. Make sure they were saved correctly and paths are valid.")
    print("Skipping qualitative evaluation.")
    # Stop here. The bare `exit()` helper is unsuitable: under IPython/Jupyter it
    # requests a kernel shutdown (losing the in-memory model) rather than raising,
    # so the code below would still run and fail on undefined encoders. Raising
    # SystemExit halts just this cell in IPython and exits a script with status 1.
    raise SystemExit(1) from e

def get_single_molecule_embeddings(smiles_string, gin_encoder, transformer_encoder):
    """Embed a single SMILES string with both pretrained encoders.

    The molecule is encoded twice, each time as a batch of one in inference
    mode (``training=False``):

    - as a token sequence by ``transformer_encoder``;
    - as a molecular graph by ``gin_encoder``.

    Args:
        smiles_string: Molecule to embed, as a SMILES string.
        gin_encoder: Graph encoder, called as
            ``gin_encoder((node_features, edge_indices, num_nodes), training=False)``.
        transformer_encoder: Sequence encoder, called as
            ``transformer_encoder((token_ids, smiles_mask), training=False)``.

    Returns:
        ``(graph_embedding, smiles_embedding)``, each L2-normalized along
        axis 1, so a dot product between them is a cosine similarity.
        Returns ``(None, None)`` after printing a warning in either case:

        - the molecule cannot be featurized as a graph;
        - the molecule has more atoms than ``MAX_NODES``.
    """
    # Sequence view: token ids (length governed by MAX_SMILES_LEN) and a mask
    # built from the '<pad>' token id.
    token_ids = tokenize_smiles(smiles_string, char_to_idx, MAX_SMILES_LEN)
    smiles_mask = create_smiles_mask(token_ids, char_to_idx['<pad>'])

    # Graph view. The edge count is not used by the GIN call below.
    node_features, edge_indices, num_nodes, _ = smiles_to_tf_graph(smiles_string)

    if node_features is None or num_nodes == 0:
        print(f"Warning: Could not featurize SMILES '{smiles_string}'. Skipping.")
        return None, None

    # tf.pad raises on negative paddings. Without this guard, a molecule larger
    # than MAX_NODES would crash the whole evaluation loop instead of being
    # skipped like other unusable inputs.
    if num_nodes > MAX_NODES:
        print(
            f"Warning: SMILES '{smiles_string}' has {num_nodes} atoms, "
            f"more than MAX_NODES={MAX_NODES}. Skipping."
        )
        return None, None

    # Zero-pad node rows up to MAX_NODES so the GIN input has a fixed shape.
    # Presumably num_nodes_for_gin lets the GIN ignore the padded rows (confirm).
    padded_node_features = tf.pad(node_features, [[0, MAX_NODES - num_nodes], [0, 0]])
    node_features_for_gin = tf.reshape(padded_node_features, (MAX_NODES, NUM_ATOM_FEATURES))
    edge_indices_for_gin = tf.cast(edge_indices, dtype=tf.int32)
    num_nodes_for_gin = tf.constant([num_nodes], dtype=tf.int32)  # a single graph

    # Leading batch axis of size 1 for the transformer inputs.
    token_ids_for_transformer = tf.expand_dims(token_ids, axis=0)
    smiles_mask_for_transformer = tf.expand_dims(smiles_mask, axis=0)

    graph_embedding = gin_encoder(
        (node_features_for_gin, edge_indices_for_gin, num_nodes_for_gin), training=False
    )
    smiles_embedding = transformer_encoder(
        (token_ids_for_transformer, smiles_mask_for_transformer), training=False
    )

    # tf.linalg.normalize returns (normalized_tensor, norm); keep only the tensor.
    graph_embedding_normalized = tf.linalg.normalize(graph_embedding, axis=1)[0]
    smiles_embedding_normalized = tf.linalg.normalize(smiles_embedding, axis=1)[0]

    return graph_embedding_normalized, smiles_embedding_normalized


#  Test Cases
# A small, deliberately varied panel of molecules. The diagonal of the matrix
# printed below pairs each molecule's graph view with its own SMILES view.
test_smiles = [
    "CCO",                    # Ethanol: small, simple alcohol
    "c1ccccc1",               # Benzene: aromatic ring
    "O=C(Cl)c1ccccc1",        # Benzoyl chloride: functional group on a ring
    "CC(=O)Oc1ccccc1C(=O)O",  # Aspirin: larger, common drug
    "C"                       # Methane: single atom, edge case
]

# Three parallel lists, in test_smiles order. Molecules the helper could not
# embed are left out of all three.
all_graph_embeddings = []
all_smiles_embeddings = []
valid_smiles_for_eval = []

for candidate in test_smiles:
    graph_vec, text_vec = get_single_molecule_embeddings(
        candidate, loaded_gin_encoder, loaded_transformer_encoder
    )
    if graph_vec is None or text_vec is None:
        continue
    all_graph_embeddings.append(graph_vec)
    all_smiles_embeddings.append(text_vec)
    valid_smiles_for_eval.append(candidate)

if valid_smiles_for_eval:
    # Stack the per-molecule rows: (num_valid_smiles, embed_dim) each.
    all_graph_embeddings_tensor = tf.concat(all_graph_embeddings, axis=0)
    all_smiles_embeddings_tensor = tf.concat(all_smiles_embeddings, axis=0)

    print("\n Cosine Similarity Matrix (Graph vs. SMILES) ")
    # Entry [i, j] is the dot product of graph_i with smiles_j.
    similarity_matrix = tf.matmul(
        all_graph_embeddings_tensor,
        all_smiles_embeddings_tensor,
        transpose_b=True,
    ).numpy()

    # Fixed 10-character cells; SMILES labels are truncated to 8 characters.
    labels = [f"{s[:8]:<10}" for s in valid_smiles_for_eval]
    rule = "-" * (10 * (len(valid_smiles_for_eval) + 1))

    print("Rows: Graph Embeddings of [SMILES]\nCols: SMILES Embeddings of [SMILES]")
    print(" " * 10 + "".join(labels))
    print(rule)
    for row_label, row_values in zip(labels, similarity_matrix):
        print(row_label + "".join(f"{value:<10.4f}" for value in row_values))
    print(rule)
else:
    print("No valid SMILES were processed for qualitative evaluation.")


 Starting Post-Pretraining Qualitative Evaluation 
Using encoders directly from trained model in memory.

 Cosine Similarity Matrix (Graph vs. SMILES) 
Rows: Graph Embeddings of [SMILES]
Cols: SMILES Embeddings of [SMILES]
          CCO       c1ccccc1  O=C(Cl)c  CC(=O)Oc  C         
------------------------------------------------------------
CCO       -0.0233   -0.0285   -0.0415   -0.0385   -0.0555   
c1ccccc1  -0.0123   -0.0068   -0.0257   -0.0207   -0.0375   
O=C(Cl)c  -0.0238   -0.0172   -0.0354   -0.0293   -0.0501   
CC(=O)Oc  -0.0284   -0.0225   -0.0409   -0.0353   -0.0532   
C         -0.0215   -0.0135   -0.0130   -0.0155   -0.0526   
------------------------------------------------------------
