# GRASP: Self-Supervised Molecular Representation Learning (Kaggle Edition)

This notebook contains the complete PyTorch implementation of the GRASP model, optimized for the Kaggle environment. 
It performs self-supervised pre-training by aligning molecular graph representations with their corresponding SMILES string representations using a free Kaggle GPU.

## 1. Install All Necessary Libraries
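This cell replaces Kaggle's preinstalled deep-learning stack with a mutually compatible set of packages:
- PyTorch 2.1.2 (CUDA 12.1) with matching PyG extensions (`pyg_lib`, `torch_scatter`, `torch_sparse`);
- pinned `transformers`/`accelerate` releases;
- PyTorch Geometric, RDKit, pandas and tqdm.

The shell commands assume Kaggle's Linux environment. If any of these libraries were already imported in the current session, restart the kernel after running the cell.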

In [10]:
# --- Final, Definitive Installation Cell for Kaggle ---
# Replaces the preinstalled torch / PyG / transformers stack with PyTorch 2.1.2 (CUDA 12.1),
# matching PyG extension wheels, and transformers/accelerate releases chosen for torch 2.1.
# Note: sentence-transformers and peft are uninstalled in Step 1 and are NOT reinstalled.
# NOTE(review): if torch or transformers were already imported in this kernel session,
# restart the kernel after this cell so the newly installed versions are the ones loaded.

# Step 1: Force uninstall all potentially conflicting libraries from the base environment.
!pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-cluster torch-spline-conv sentence-transformers transformers accelerate peft -y --quiet

# Step 2: Install a specific, known-stable combination of torch and torchvision for CUDA 12.1.
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 --quiet

# Step 3: Install SPECIFIC, older versions of transformers and accelerate that are compatible with torch 2.1.
# (timm only has a lower bound here, so its installed version can drift over time.)
!pip install "transformers==4.36.2" "accelerate==0.25.0" "timm>=0.9.2" --quiet

# Step 4: Install the remaining libraries.
# NOTE(review): these are NOT version-pinned, so the "fully version-locked" message printed
# below overstates things; a future torch_geometric release could stop supporting torch 2.1.
# Consider pinning exact versions.
# NOTE(review): `rdkit-pypi` appears to be the legacy PyPI name for RDKit (now published
# as `rdkit`) — confirm before changing.
!pip install torch_geometric rdkit-pypi pandas tqdm --quiet

# Step 5: Install the PyG dependencies that match the torch version we installed in Step 2.
# The wheel index is the torch-2.1.0 + cu121 one; presumably it serves any torch 2.1.x build —
# confirm against the PyG installation docs.
!pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html --quiet

print("\nInstallation of a fully version-locked, compatible library set is complete.")
print("This setup should resolve all dependency conflicts. Please proceed with the rest of the notebook.")

[0m
Installation of a fully version-locked, compatible library set is complete.
This setup should resolve all dependency conflicts. Please proceed with the rest of the notebook.


## 2. Imports and Configuration
Import all libraries and define the key parameters for our training run.

In [11]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_add_pool

from transformers import AutoModel, AutoConfig, AutoTokenizer
from rdkit import Chem, rdBase
from tqdm import tqdm
import pandas as pd

# Silence RDKit's console logging. This turns off the *error* channel as well as the
# warning channel, so SMILES that fail to parse are dropped without any message
# (Chem.MolFromSmiles still signals the failure by returning None).
rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

# --- Data ---
# IMPORTANT: UPDATE THIS VARIABLE with the folder name of your uploaded dataset.
dataset_folder_name = 'pubchem-smiles-for-pretraining-txt'
SMILES_FILE_PATH = f'/kaggle/input/{dataset_folder_name}/pubchem_smiles_for_pretraining.txt'

# Hugging Face checkpoint whose tokenizer converts SMILES strings into token ids.
TOKENIZER_NAME = 'seyonec/ChemBERTa-zinc-base-v1'

# --- Reproducibility ---
# torch.manual_seed seeds the CPU and CUDA generators used for weight initialisation,
# dropout and DataLoader shuffling. Some GPU kernels (e.g. scatter-add aggregation)
# may still be non-deterministic, so runs can differ slightly at the bit level.
SEED = 42
torch.manual_seed(SEED)

# --- Training Parameters ---
NUM_SAMPLES = 500000   # molecules to keep from the file (None = every line that passes filtering)
BATCH_SIZE = 64        # kept small to fit GPU memory; also the number of candidates per InfoNCE row
EPOCHS = 5
LEARNING_RATE = 1e-4
TEMPERATURE = 0.07     # softmax temperature of the InfoNCE loss

# --- Model Parameters (dimensions kept small to save memory) ---
PROJECTION_DIM = 128   # size of the shared graph/SMILES embedding space
GRAPH_EMB_DIM = 128    # hidden width of the GIN graph encoder
GRAPH_LAYERS = 4       # number of GIN message-passing layers

# --- System Parameters ---
NUM_WORKERS = 2  # DataLoader worker processes (0 = load data in the main process)

# --- Verification ---
# isfile (rather than exists) also rejects a path that points at a directory.
if not os.path.isfile(SMILES_FILE_PATH):
    raise FileNotFoundError(
        f"Dataset file not found at '{SMILES_FILE_PATH}'. "
        "Please check the 'dataset_folder_name' variable and your uploaded file's name."
    )
print(f"Dataset found: {SMILES_FILE_PATH}")

Dataset found: /kaggle/input/pubchem-smiles-for-pretraining-txt/pubchem_smiles_for_pretraining.txt


## 3. Utility Functions (SMILES to Graph Conversion)
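These functions convert a SMILES string into a PyTorch Geometric `Data` graph. Each atom becomes a node whose feature vector concatenates one-hot encodings of:
- element;
- degree;
- formal charge;
- hybridization;
- aromaticity;
- ring membership.

Each bond becomes a pair of directed edges.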

In [12]:
# Vocabularies for the one-hot atom features. Every value listed here gets its own
# slot; one_hot_encode adds one extra "other" slot per feature for anything not
# listed (e.g. a degree of 6 or more, or a formal charge outside -2..+2).
ATOM_FEATURE_MAP = dict(
    atomic_num=list(range(1, 119)),     # atomic numbers 1..118
    degree=list(range(6)),              # atom degree 0..5
    formal_charge=list(range(-2, 3)),   # formal charge -2..+2
    hybridization=[
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    is_aromatic=[0, 1],
    is_in_ring=[0, 1],
)

def one_hot_encode(value, choices):
    """One-hot encode ``value`` against ``choices`` plus a trailing "other" slot.

    Args:
        value: The value to encode; matched against ``choices`` with ``==``.
        choices: Ordered list of known values.

    Returns:
        A list of ``len(choices) + 1`` ints containing a single 1: at the position
        of ``value`` in ``choices``, or in the final slot if ``value`` is not listed.
    """
    hot_position = choices.index(value) if value in choices else len(choices)
    return [int(position == hot_position) for position in range(len(choices) + 1)]

def get_atom_features(atom):
    """Encode one RDKit atom as a float tensor of concatenated one-hot features.

    Feature order: atomic number, degree, formal charge, hybridization, aromatic
    flag, ring-membership flag. Each uses the matching vocabulary in
    ATOM_FEATURE_MAP plus an "other" slot, so the tensor length equals
    get_num_node_features().
    """
    readings = (
        ('atomic_num', atom.GetAtomicNum()),
        ('degree', atom.GetDegree()),
        ('formal_charge', atom.GetFormalCharge()),
        ('hybridization', atom.GetHybridization()),
        ('is_aromatic', int(atom.GetIsAromatic())),
        ('is_in_ring', int(atom.IsInRing())),
    )
    feature_vector = []
    for feature_name, raw_value in readings:
        feature_vector.extend(one_hot_encode(raw_value, ATOM_FEATURE_MAP[feature_name]))
    return torch.tensor(feature_vector, dtype=torch.float)

def get_num_node_features():
    """Return the length of the atom feature vector built by get_atom_features.

    Every feature in ATOM_FEATURE_MAP contributes one slot per listed value plus
    one shared "other" slot.
    """
    width = 0
    for vocabulary in ATOM_FEATURE_MAP.values():
        width += len(vocabulary) + 1
    return width

def smiles_to_graph_data(smiles_string: str):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Atoms become nodes (features from get_atom_features); each bond becomes two
    directed edges, i -> j and j -> i.

    Args:
        smiles_string: The molecule in SMILES notation.

    Returns:
        ``Data(x=..., edge_index=..., smiles=smiles_string)``, where ``x`` stacks one
        feature vector per atom and ``edge_index`` always has shape
        ``(2, num_edges)``. A bond-free molecule such as methane ("C") gets a
        ``(2, 0)`` edge_index. Returns ``None`` if the string cannot be parsed,
        yields no atoms, or featurisation fails.
    """
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        # A 0-atom molecule cannot form a graph (torch.stack of an empty list fails).
        if mol is None or mol.GetNumAtoms() == 0:
            return None

        x = torch.stack([get_atom_features(atom) for atom in mol.GetAtoms()])

        edge_pairs = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_pairs.extend([(i, j), (j, i)])

        if edge_pairs:
            edge_index = torch.tensor(edge_pairs, dtype=torch.long).t().contiguous()
        else:
            # torch.tensor([]) would be 1-D with shape (0,); PyG batching and message
            # passing expect a 2-D (2, num_edges) edge_index even when it is empty.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        return Data(x=x, edge_index=edge_index, smiles=smiles_string)

    except Exception:
        # Deliberate best effort: any RDKit/featurisation failure marks the sample as
        # unusable; MoleculeDataset returns None and CustomCollator drops it.
        return None

## 4. Data Pipeline (Dataset and Collator)
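This section defines the data pipeline:
- A PyTorch `Dataset` that loads and filters SMILES strings and yields (graph, token) pairs.
- A `CustomCollator` class that drops samples that failed to featurize, batches the graphs with PyG, and pads the SMILES tokens to the longest sequence in each batch.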

In [13]:
# --- CORRECTED DATA PIPELINE CELL ---

class MoleculeDataset(Dataset):
    """SMILES dataset yielding ``(graph, tokens)`` pairs for contrastive pre-training.

    ``file_path`` is read one SMILES per line. A line is kept only if it is non-blank
    and RDKit parses it into a molecule with between 1 and ``max_atoms`` atoms.
    Reading stops once ``num_samples`` molecules have been kept.

    Args:
        file_path: Text file with one SMILES string per line.
        tokenizer_name: Hugging Face tokenizer id used in ``__getitem__``.
        num_samples: Maximum number of molecules to keep (``None`` = no limit).
        max_atoms: Molecules with more atoms than this are skipped.
        max_length: Maximum number of SMILES tokens; longer sequences are truncated.
    """
    def __init__(self, file_path, tokenizer_name, num_samples=None, max_atoms=512, max_length=256):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_atoms = max_atoms
        self.max_length = max_length

        print("Loading and filtering SMILES strings from file...")
        self.smiles_list = []
        with open(file_path, 'r') as f:
            # tqdm shows progress during the initial read-and-filter pass.
            for line in tqdm(f, desc="Reading file"):
                if num_samples is not None and len(self.smiles_list) >= num_samples:
                    break

                smiles = line.strip()
                # Skip blank lines, and use string length as a cheap pre-filter for
                # strings far too long to fit within max_atoms.
                if not smiles or len(smiles) > self.max_atoms * 2:
                    continue

                # Check the actual atom count with RDKit. Requiring at least one atom
                # keeps out empty molecules, which smiles_to_graph_data cannot turn
                # into a graph and would otherwise be dropped batch by batch.
                mol = Chem.MolFromSmiles(smiles)
                if mol is not None and 0 < mol.GetNumAtoms() <= self.max_atoms:
                    self.smiles_list.append(smiles)

        print(f"Loaded {len(self.smiles_list)} molecules after filtering (max atoms = {self.max_atoms}).")

    def __len__(self):
        """Number of molecules that passed filtering."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return ``(graph, tokens)`` for one molecule, or ``None`` if featurisation fails.

        ``tokens`` holds the tokenizer's output fields (e.g. ``input_ids``,
        ``attention_mask``) as 1-D tensors, with the batch dimension squeezed away
        so the collator can pad a list of them.
        """
        smiles = self.smiles_list[idx]

        graph_data = smiles_to_graph_data(smiles)
        if graph_data is None:
            return None

        smiles_tokens = self.tokenizer(
            smiles,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        smiles_tokens = {key: val.squeeze(0) for key, val in smiles_tokens.items()}

        return graph_data, smiles_tokens

class CustomCollator:
    """Collate ``(graph, tokens)`` samples from MoleculeDataset into model-ready batches.

    Samples the dataset returned as ``None`` are discarded first; if none remain,
    ``(None, None)`` is returned so the training loop can skip the batch. Otherwise
    the graphs are merged into a PyG ``Batch``, and the token sequences are padded
    to the longest one in the batch.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        survivors = [sample for sample in batch if sample is not None]
        if len(survivors) == 0:
            return None, None

        graph_list = [graph for graph, _ in survivors]
        token_dicts = [tokens for _, tokens in survivors]

        graph_batch = Batch.from_data_list(graph_list)
        padded_tokens = self.tokenizer.pad(
            {
                'input_ids': [tokens['input_ids'] for tokens in token_dicts],
                'attention_mask': [tokens['attention_mask'] for tokens in token_dicts],
            },
            return_tensors='pt',
            padding='longest',
        )
        return graph_batch, padded_tokens

## 5. Model Architecture
This section defines the three core PyTorch `nn.Module` classes: `GraphEncoder`, `SmilesEncoder`, and the main `GRASPModel`.

In [14]:
class GraphEncoder(nn.Module):
    """GIN graph encoder with sum pooling.

    Each of ``num_layers`` layers applies a GINConv (two-layer MLP with a learnable
    epsilon), then BatchNorm, ReLU and dropout. Node states are then summed per
    graph, giving one ``embedding_dim``-sized vector per molecule.

    Args:
        num_node_features: Size of the input node feature vectors (first layer only).
        embedding_dim: Width of every layer and of the pooled output.
        num_layers: Number of GIN layers.
        dropout: Dropout probability applied after each layer while training.
    """

    def __init__(self, num_node_features, embedding_dim, num_layers, dropout):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = dropout
        self.num_layers = num_layers

        for depth in range(num_layers):
            # Only the first layer consumes raw atom features.
            fan_in = int(embedding_dim) if depth else int(num_node_features)
            width = int(embedding_dim)

            # Build the MLP, then the conv, then the norm, in that order, so
            # parameter initialisation draws from the RNG in the same sequence.
            gin_mlp = nn.Sequential(
                nn.Linear(fan_in, 2 * width),
                nn.ReLU(),
                nn.Linear(2 * width, width),
            )
            self.convs.append(GINConv(gin_mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(width))

    def forward(self, x, edge_index, batch):
        """Return per-graph embeddings by sum-pooling node states over ``batch`` ids."""
        node_state = x
        for conv, norm in zip(self.convs, self.batch_norms):
            node_state = norm(conv(node_state, edge_index))
            node_state = F.dropout(F.relu(node_state), p=self.dropout, training=self.training)
        return global_add_pool(node_state, batch)

class SmilesEncoder(nn.Module):
    """Pretrained Hugging Face transformer mapping tokenised SMILES to one vector each.

    The sequence embedding is the last-layer hidden state at position 0 (the
    first token of every sequence).

    Args:
        model_name: Hugging Face model id (or local path) for the config and weights.
        dropout: Written to the config's ``hidden_dropout_prob`` before the weights
            are loaded, so for BERT/RoBERTa-style models it sets the transformer's
            hidden-layer dropout. Pass ``None`` to keep the checkpoint's own value.
    """
    def __init__(self, model_name='seyonec/ChemBERTa-zinc-base-v1', dropout=0.1):
        super(SmilesEncoder, self).__init__()
        config = AutoConfig.from_pretrained(model_name)
        if dropout is not None:
            # NOTE(review): configs that name their dropout field differently will
            # simply ignore this attribute.
            config.hidden_dropout_prob = dropout
        self.transformer = AutoModel.from_pretrained(model_name, config=config)
        # Width of the returned embeddings; sizes the SMILES projection head.
        self.smiles_embedding_dim = self.transformer.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at position 0 for each sequence."""
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

class GRASPModel(nn.Module):
    """Dual encoder embedding molecular graphs and SMILES strings into a shared space.

    A GIN graph encoder and a pretrained SMILES transformer each feed a two-layer MLP
    projection head. Both outputs are L2-normalised, so their dot product is a cosine
    similarity (as consumed by InfoNCELoss).

    Args:
        graph_emb_dim: Hidden/pooled width of the graph encoder.
        graph_layers: Number of GIN layers.
        projection_dim: Size of the shared embedding space.
        dropout: Dropout probability forwarded to both encoders.
        smiles_model_name: Hugging Face checkpoint for the SMILES encoder. Defaults
            to ``TOKENIZER_NAME`` so the encoder always matches the tokenizer the data
            pipeline uses to produce ``input_ids``; a mismatch would silently feed
            token ids from one vocabulary into a model trained on another.
    """
    def __init__(self, graph_emb_dim, graph_layers, projection_dim, dropout=0.1, smiles_model_name=None):
        super(GRASPModel, self).__init__()
        if smiles_model_name is None:
            smiles_model_name = TOKENIZER_NAME

        self.graph_encoder = GraphEncoder(
            num_node_features=get_num_node_features(),
            embedding_dim=graph_emb_dim,
            num_layers=graph_layers,
            dropout=dropout
        )

        self.smiles_encoder = SmilesEncoder(model_name=smiles_model_name, dropout=dropout)
        smiles_dim = self.smiles_encoder.smiles_embedding_dim

        self.graph_projection = nn.Sequential(
            nn.Linear(graph_emb_dim, graph_emb_dim),
            nn.ReLU(),
            nn.Linear(graph_emb_dim, projection_dim)
        )

        self.smiles_projection = nn.Sequential(
            nn.Linear(smiles_dim, smiles_dim),
            nn.ReLU(),
            nn.Linear(smiles_dim, projection_dim)
        )

    def forward(self, graph_batch, smiles_batch):
        """Return ``(graph_proj, smiles_proj)``, each L2-normalised along dim 1."""
        graph_embeddings = self.graph_encoder(
            x=graph_batch.x, edge_index=graph_batch.edge_index, batch=graph_batch.batch
        )
        smiles_embeddings = self.smiles_encoder(
            input_ids=smiles_batch['input_ids'], attention_mask=smiles_batch['attention_mask']
        )

        graph_proj = self.graph_projection(graph_embeddings)
        smiles_proj = self.smiles_projection(smiles_embeddings)

        return F.normalize(graph_proj, p=2, dim=1, eps=1e-8), F.normalize(smiles_proj, p=2, dim=1, eps=1e-8)

## 6. Pre-training Script
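This section defines a symmetric InfoNCE loss and the main training function that ties everything together. In the loss, each molecule's graph embedding must pick out its own SMILES embedding among all SMILES in the batch, and vice versa; the other molecules in the batch act as negatives.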

In [15]:
class InfoNCELoss(nn.Module):
    """Symmetric InfoNCE (contrastive) loss with in-batch negatives.

    ``z_i`` and ``z_j`` are ``(N, D)`` batches whose i-th rows form a positive pair.
    The logits are their dot products divided by ``temperature``. Cross-entropy
    towards the diagonal is averaged over both directions (i -> j and j -> i).
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        logits = (z_i @ z_j.T) / self.temperature
        targets = torch.arange(z_i.size(0), device=z_i.device)  # row k's positive is column k
        forward_loss = self.loss_fn(logits, targets)
        backward_loss = self.loss_fn(logits.T, targets)
        return (forward_loss + backward_loss) / 2

def _train_one_epoch(model, data_loader, criterion, optimizer, scheduler, device, desc):
    """Run one optimisation pass over ``data_loader`` and return the mean batch loss.

    Batches where the collator returned ``(None, None)`` (every sample failed to
    featurise) are skipped and excluded from the average. Returns 0.0 if no batch
    was processed.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(data_loader, desc=desc)
    for graph_batch, smiles_batch in progress_bar:
        if graph_batch is None:
            continue

        graph_batch = graph_batch.to(device)
        smiles_batch = {key: val.to(device) for key, val in smiles_batch.items()}

        optimizer.zero_grad()
        graph_proj, smiles_proj = model(graph_batch, smiles_batch)
        loss = criterion(graph_proj, smiles_proj)
        loss.backward()
        optimizer.step()
        # The cosine schedule's T_max is counted in batches, so step once per batch.
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

    return total_loss / num_batches if num_batches else 0.0


def train_grasp(output_dir="/kaggle/working"):
    """Contrastively pre-train GRASP on the SMILES file at ``SMILES_FILE_PATH``.

    The setup:
    - the dataset and DataLoader;
    - the model and a symmetric InfoNCE loss;
    - AdamW with a per-batch cosine learning-rate schedule.

    Training runs for ``EPOCHS`` epochs. After each epoch, a ``state_dict``
    checkpoint named ``grasp_model_epoch_{k}.pt`` is written to ``output_dir``.

    Args:
        output_dir: Directory for checkpoints (created if missing). Defaults to
            Kaggle's writable working directory.

    Raises:
        ValueError: If no molecules survive dataset filtering.
    """
    # --- Device Setup ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("No CUDA GPU found, using CPU.")

    # --- Data Pipeline ---
    dataset = MoleculeDataset(SMILES_FILE_PATH, TOKENIZER_NAME, num_samples=NUM_SAMPLES)
    if len(dataset) == 0:
        raise ValueError(f"No usable molecules were loaded from '{SMILES_FILE_PATH}'.")
    collator = CustomCollator(tokenizer=dataset.tokenizer)
    data_loader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator,
        num_workers=NUM_WORKERS, pin_memory=device.type == 'cuda',
        persistent_workers=NUM_WORKERS > 0
    )

    # --- Model, Loss, and Optimizer ---
    model = GRASPModel(
        projection_dim=PROJECTION_DIM,
        graph_emb_dim=GRAPH_EMB_DIM,
        graph_layers=GRAPH_LAYERS
    ).to(device)
    criterion = InfoNCELoss(temperature=TEMPERATURE)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(data_loader) * EPOCHS)

    os.makedirs(output_dir, exist_ok=True)

    # --- Training Loop ---
    print("\nStarting pre-training...")
    for epoch in range(EPOCHS):
        avg_loss = _train_one_epoch(
            model, data_loader, criterion, optimizer, scheduler, device,
            desc=f"Epoch {epoch+1}/{EPOCHS}"
        )
        print(f"\nEpoch {epoch+1}/{EPOCHS} - Average Loss: {avg_loss:.4f}")

        # Save a model checkpoint to the output directory.
        checkpoint_path = os.path.join(output_dir, f"grasp_model_epoch_{epoch+1}.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved to {checkpoint_path}")

    print("\nPre-training complete.")

## 7. Run Training
This cell executes the main training function.

In [17]:
# Build the dataset and model, train for EPOCHS epochs, and write a checkpoint
# (grasp_model_epoch_<k>.pt under /kaggle/working) after every epoch.
train_grasp()

Using CUDA GPU: Tesla T4
Loading and filtering SMILES strings from file...


Reading file: 500537it [01:30, 5527.35it/s]


Loaded 500000 molecules after filtering (max atoms = 512).

Starting pre-training...


Epoch 1/5:   0%|          | 0/7813 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
2025-06-23 18:40:35.590377: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-23 18:40:35.590760: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750704035.625684     492 cuda_dnn.cc:8310] Unable to regis


Epoch 1/5 - Average Loss: 0.0786
Model checkpoint saved to /kaggle/working/grasp_model_epoch_1.pt


  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 


Epoch 2/5 - Average Loss: 0.0139
Model checkpoint saved to /kaggle/working/grasp_model_epoch_2.pt


  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 


Epoch 3/5 - Average Loss: 0.0076
Model checkpoint saved to /kaggle/working/grasp_model_epoch_3.pt


  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 


Epoch 4/5 - Average Loss: 0.0050
Model checkpoint saved to /kaggle/working/grasp_model_epoch_4.pt


  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 


Epoch 5/5 - Average Loss: 0.0039
Model checkpoint saved to /kaggle/working/grasp_model_epoch_5.pt

Pre-training complete.


## 8. Qualitative Evaluation
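After training, this section sanity-checks the learned alignment. It loads the final checkpoint, embeds a few test molecules with both encoders, and computes the cosine similarity between every graph embedding and every SMILES embedding.

Good alignment means each diagonal entry (a molecule's graph vs. its own SMILES) is clearly higher than the off-diagonal entries in the same row and column. A high diagonal on its own is not enough.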

In [19]:
def _embed_test_molecules(model, tokenizer, smiles_list, device):
    """Embed each SMILES with both encoders; return ``(graph_emb, smiles_emb, kept)``.

    Molecules that smiles_to_graph_data cannot featurise are reported and skipped.
    Row i of both tensors corresponds to ``kept[i]``. If nothing could be embedded,
    both tensors are ``None``.
    """
    graph_rows, smiles_rows, kept = [], [], []
    with torch.no_grad():
        for smiles in smiles_list:
            graph_data = smiles_to_graph_data(smiles)
            if graph_data is None:
                print(f"Warning: Could not process SMILES: {smiles}")
                continue

            # Bond-free molecules (e.g. methane) may come back with a 1-D empty
            # edge_index; PyG expects shape (2, num_edges).
            if graph_data.edge_index.dim() == 1:
                graph_data.edge_index = torch.empty((2, 0), dtype=torch.long)

            smiles_tokens = tokenizer(smiles, return_tensors='pt', padding=True)
            graph_batch = Batch.from_data_list([graph_data]).to(device)
            smiles_batch = {k: v.to(device) for k, v in smiles_tokens.items()}

            graph_proj, smiles_proj = model(graph_batch, smiles_batch)
            graph_rows.append(graph_proj)
            smiles_rows.append(smiles_proj)
            kept.append(smiles)

    if not kept:
        return None, None, kept
    return torch.cat(graph_rows, dim=0), torch.cat(smiles_rows, dim=0), kept


def _report_alignment(similarity):
    """Print diagonal-vs-off-diagonal statistics for a square similarity tensor.

    Row i holds graph embedding i against every SMILES embedding, so a perfectly
    aligned model has its row-wise and column-wise maxima on the diagonal.
    """
    n = similarity.size(0)
    idx = torch.arange(n, device=similarity.device)
    diag_mean = similarity.diagonal().mean().item()
    row_hits = int((similarity.argmax(dim=1) == idx).sum().item())
    col_hits = int((similarity.argmax(dim=0) == idx).sum().item())

    print(f"Mean diagonal similarity (matching pairs): {diag_mean:.4f}")
    if n > 1:
        off_diagonal = ~torch.eye(n, dtype=torch.bool, device=similarity.device)
        print(f"Mean off-diagonal similarity (mismatched pairs): {similarity[off_diagonal].mean().item():.4f}")
    print(f"Graph -> SMILES top-1 retrieval: {row_hits}/{n}")
    print(f"SMILES -> graph top-1 retrieval: {col_hits}/{n}")
    if row_hits == n and col_hits == n:
        print("Every molecule's graph and SMILES embeddings are each other's best match: "
              "representations are aligned on this test panel.")
    else:
        print("Some molecules score higher against a different molecule than against their "
              "own counterpart: alignment on this test panel is incomplete.")


def evaluate_model(checkpoint_path=None, test_smiles=None):
    """Qualitatively check graph/SMILES alignment of a trained GRASP checkpoint.

    The function:
    1. loads the checkpoint and embeds each test molecule with both encoders;
    2. prints the graph-vs-SMILES cosine-similarity matrix;
    3. prints diagonal vs. off-diagonal statistics and top-1 retrieval counts,
       instead of asserting success unconditionally.

    Args:
        checkpoint_path: ``state_dict`` file saved by ``train_grasp``. Defaults to
            the final-epoch checkpoint in ``/kaggle/working``.
        test_smiles: SMILES strings to embed. Defaults to a small fixed panel
            (ethanol, benzene, aspirin, caffeine, methane).
    """
    print("\n Starting Post-Pretraining Qualitative Evaluation ")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Evaluation using device: {device}")

    if checkpoint_path is None:
        final_checkpoint_path = f"/kaggle/working/grasp_model_epoch_{EPOCHS}.pt"
    else:
        final_checkpoint_path = checkpoint_path
    if not os.path.exists(final_checkpoint_path):
        print(f"Error: Checkpoint file not found at {final_checkpoint_path}. Cannot evaluate.")
        return

    model = GRASPModel(
        projection_dim=PROJECTION_DIM,
        graph_emb_dim=GRAPH_EMB_DIM,
        graph_layers=GRAPH_LAYERS
    ).to(device)
    # A state_dict only contains tensors, so the safer weights-only unpickler suffices.
    model.load_state_dict(torch.load(final_checkpoint_path, map_location=device, weights_only=True))
    model.eval()
    print(f"Successfully loaded model from {final_checkpoint_path}")

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    if test_smiles is None:
        test_smiles = [
            "CCO",                           # Ethanol
            "c1ccccc1",                      # Benzene
            "CC(=O)Oc1ccccc1C(=O)O",         # Aspirin
            "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # Caffeine
            "C"                              # Methane
        ]

    print("\nGenerating embeddings for test molecules...")
    graph_embeddings_tensor, smiles_embeddings_tensor, valid_smiles_for_display = _embed_test_molecules(
        model, tokenizer, test_smiles, device
    )
    if graph_embeddings_tensor is None:
        print("No embeddings were generated. Cannot create similarity matrix.")
        return

    # GRASPModel L2-normalises both outputs, so the dot product is cosine similarity.
    similarity = torch.matmul(graph_embeddings_tensor, smiles_embeddings_tensor.T)

    print("\nCosine Similarity Matrix (Graph vs. SMILES) ")
    df = pd.DataFrame(
        similarity.cpu().numpy(), index=valid_smiles_for_display, columns=valid_smiles_for_display
    )
    print(df.round(4))

    print("\nKey Observations ")
    _report_alignment(similarity)


# Run the qualitative similarity check on the final-epoch checkpoint.
evaluate_model()


--- Starting Post-Pretraining Qualitative Evaluation ---
Evaluation using device: cuda
Successfully loaded model from /kaggle/working/grasp_model_epoch_5.pt

Generating embeddings for test molecules...

--- Cosine Similarity Matrix (Graph vs. SMILES) ---
                                 CCO  c1ccccc1  CC(=O)Oc1ccccc1C(=O)O  \
CCO                           0.9609    0.3435                -0.0449   
c1ccccc1                      0.3822    0.8747                -0.0821   
CC(=O)Oc1ccccc1C(=O)O        -0.0338   -0.0404                 0.9520   
CN1C=NC2=C1C(=O)N(C(=O)N2C)C  0.0961    0.2039                 0.0124   
C                             0.4803    0.3894                 0.2627   

                              CN1C=NC2=C1C(=O)N(C(=O)N2C)C       C  
CCO                                                 0.2041  0.5417  
c1ccccc1                                            0.3024  0.2932  
CC(=O)Oc1ccccc1C(=O)O                              -0.2467  0.2543  
CN1C=NC2=C1C(=O)N(C(=O)N2C)C 