# GRASP: Self-Supervised Molecular Representation Learning (Kaggle Edition)

This notebook contains the complete PyTorch implementation of the GRASP model, adapted for the Kaggle environment. It performs self-supervised pre-training on a free Kaggle GPU by aligning each molecule's graph representation with the representation of its SMILES string.

**Instructions:**
1.  Ensure the **GPU is enabled** as the accelerator in the notebook settings.
2.  Use the `+ Add data` button to attach your uploaded PubChem SMILES dataset.
3.  **Update the `dataset_folder_name` variable** in Section 2 (Imports and Configuration) to match your dataset's folder name.

## 1. Install All Necessary Libraries
This cell will install all required packages.
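
Once the install cell below has run, it can be useful to confirm that the kernel actually sees the pinned versions. The following is a minimal sketch using only the standard library; the helper `find_version_mismatches` and its `lookup` parameter are illustrative names, not part of the notebook, and the pins are copied from the install commands.

```python
from importlib import metadata

# Pins copied from the install cell; local build tags such as "+cu121" are ignored.
PINNED = {"torch": "2.1.2", "transformers": "4.36.2", "accelerate": "0.25.0"}

def find_version_mismatches(pins, lookup=metadata.version):
    """Return {package: (expected, found)} for every pin that is not satisfied."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = lookup(name).split("+")[0]
        except metadata.PackageNotFoundError:
            found = None  # the package is not installed at all
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches
```

Calling `find_version_mismatches(PINNED)` after installation returns an empty dictionary when every pin is satisfied; any entry it does return names a package whose version differs from the one requested.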

In [10]:

!pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-cluster torch-spline-conv sentence-transformers transformers accelerate peft -y --quiet

# CUDA 12.1 builds of torch, torchvision, and torchaudio
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 --quiet

# Pinned, older releases of transformers and accelerate
!pip install "transformers==4.36.2" "accelerate==0.25.0" "timm>=0.9.2" --quiet

!pip install torch_geometric rdkit-pypi pandas tqdm --quiet

!pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html --quiet

print("\nInstallation of a fully version-locked, compatible library set is complete.")


Installation of a fully version-locked, compatible library set is complete.
This setup should resolve all dependency conflicts. Please proceed with the rest of the notebook.


## 2. Imports and Configuration
Import all libraries and define the key parameters for the training run.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_add_pool

from transformers import AutoModel, AutoConfig, AutoTokenizer
from rdkit import Chem, rdBase
from tqdm import tqdm
import pandas as pd

rdBase.DisableLog('rdApp.warning')
rdBase.DisableLog('rdApp.error')

dataset_folder_name = 'pubchem-smiles-for-pretraining-txt'
SMILES_FILE_PATH = f'/kaggle/input/{dataset_folder_name}/pubchem_smiles_for_pretraining.txt'


TOKENIZER_NAME = 'seyonec/ChemBERTa-zinc-base-v1'


NUM_SAMPLES = 500000

BATCH_SIZE = 64
EPOCHS = 5
LEARNING_RATE = 1e-4
TEMPERATURE = 0.07

PROJECTION_DIM = 128
GRAPH_EMB_DIM = 128
GRAPH_LAYERS = 4


NUM_WORKERS = 2 

# Verify that the dataset file exists before going any further
if not os.path.exists(SMILES_FILE_PATH):
    raise FileNotFoundError(
        f"Dataset file not found at '{SMILES_FILE_PATH}'. "
        "Please check the 'dataset_folder_name' variable and your uploaded file's name."
    )
print(f"Dataset found: {SMILES_FILE_PATH}")

## 3. Utility Functions (SMILES to Graph Conversion)
These functions handle the conversion of a SMILES string into a graph data structure with rich, one-hot encoded atom features.
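
To make the encoding concrete: each feature is one-hot encoded over its list of allowed values plus one extra slot that collects anything outside that list. The sketch below reproduces the scheme without RDKit, so plain strings stand in for RDKit's `HybridizationType` values (an assumption made only to keep it self-contained), and `one_hot_with_unknown` and `FEATURE_CHOICES` are illustrative names.

```python
# Stand-alone copy of the encoding scheme. Plain strings replace RDKit's
# HybridizationType values so that this runs without RDKit installed.
FEATURE_CHOICES = {
    "atomic_num": list(range(1, 119)),
    "degree": list(range(6)),
    "formal_charge": list(range(-2, 3)),
    "hybridization": ["SP", "SP2", "SP3", "SP3D", "SP3D2", "UNSPECIFIED"],
    "is_aromatic": [0, 1],
    "is_in_ring": [0, 1],
}

def one_hot_with_unknown(value, choices):
    """One-hot encode `value`; anything outside `choices` lands in the extra last slot."""
    encoding = [0] * (len(choices) + 1)
    index = choices.index(value) if value in choices else len(choices)
    encoding[index] = 1
    return encoding

# A neutral atom uses the slot for charge 0; a +4 charge falls into the "unknown" slot.
neutral = one_hot_with_unknown(0, FEATURE_CHOICES["formal_charge"])
unusual = one_hot_with_unknown(4, FEATURE_CHOICES["formal_charge"])

# Every feature contributes len(choices) + 1 entries to the node feature vector.
num_node_features = sum(len(c) + 1 for c in FEATURE_CHOICES.values())

print(neutral)            # [0, 0, 1, 0, 0, 0]
print(unusual)            # [0, 0, 0, 0, 0, 1]
print(num_node_features)  # 145
```

With the feature lists used in this notebook, every atom is therefore described by a 145-dimensional vector, which is the input width of the first graph layer.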

In [12]:
# Allowed values for each one-hot encoded atom feature
ATOM_FEATURE_MAP = {
    'atomic_num': list(range(1, 119)),
    'degree': list(range(6)),
    'formal_charge': list(range(-2, 3)),
    'hybridization': [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'is_aromatic': [0, 1],
    'is_in_ring': [0, 1]
}

def one_hot_encode(value, choices):
    encoding = [0] * (len(choices) + 1) 
    try:
        index = choices.index(value)
    except ValueError:
        index = -1
    encoding[index] = 1
    return encoding

def get_atom_features(atom):
    features = []
    features += one_hot_encode(atom.GetAtomicNum(), ATOM_FEATURE_MAP['atomic_num'])
    features += one_hot_encode(atom.GetDegree(), ATOM_FEATURE_MAP['degree'])
    features += one_hot_encode(atom.GetFormalCharge(), ATOM_FEATURE_MAP['formal_charge'])
    features += one_hot_encode(atom.GetHybridization(), ATOM_FEATURE_MAP['hybridization'])
    features += one_hot_encode(int(atom.GetIsAromatic()), ATOM_FEATURE_MAP['is_aromatic'])
    features += one_hot_encode(int(atom.IsInRing()), ATOM_FEATURE_MAP['is_in_ring'])
    return torch.tensor(features, dtype=torch.float)

def get_num_node_features():
    return sum(len(choices) + 1 for choices in ATOM_FEATURE_MAP.values())

def smiles_to_graph_data(smiles_string: str):
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None: return None

        atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
        x = torch.stack(atom_features)

        edge_indices = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices.extend([(i, j), (j, i)])
        
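        if edge_indices:
            edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        else:
            # Molecules without bonds (e.g. methane, "C") need an explicit empty (2, 0) edge index
            edge_index = torch.empty((2, 0), dtype=torch.long)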

        return Data(x=x, edge_index=edge_index, smiles=smiles_string)

    except Exception:
        return None

## 4. Data Pipeline (Dataset and Collator)
This section defines the data pipeline: a PyTorch `Dataset` that filters and tokenizes molecules, and a custom collator class that builds batches dynamically, padding each batch only to its own longest SMILES sequence.
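
Dynamic batching here means that each batch is padded to the length of its own longest sequence rather than to a fixed maximum. The plain-Python sketch below illustrates the idea independently of the tokenizer's `pad` method used in the cell that follows; `pad_to_longest` and the token ids are made up for illustration.

```python
def pad_to_longest(token_id_lists, pad_id=0):
    """Pad each sequence to the longest one in this batch and build attention masks."""
    longest = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        padding = longest - len(ids)
        input_ids.append(list(ids) + [pad_id] * padding)
        attention_mask.append([1] * len(ids) + [0] * padding)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Two made-up token sequences of different lengths (the ids are arbitrary).
batch = pad_to_longest([[5, 6, 7], [5, 6, 7, 8, 9]], pad_id=1)
print(batch["input_ids"])       # [[5, 6, 7, 1, 1], [5, 6, 7, 8, 9]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Because short molecules dominate most batches, padding per batch avoids spending transformer compute on positions that the attention mask would zero out anyway.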

In [13]:


class MoleculeDataset(Dataset):
    """
    PyTorch Dataset of SMILES strings.
    Molecules with more than `max_atoms` atoms are filtered out at load time.
    """
    def __init__(self, file_path, tokenizer_name, num_samples=None, max_atoms=512):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_atoms = max_atoms
        
        print("Loading and filtering SMILES strings from file")
        self.smiles_list = []
        with open(file_path, 'r') as f:
            # Using tqdm to show progress 
            for i, line in enumerate(tqdm(f, desc="Reading file")):
                if num_samples is not None and len(self.smiles_list) >= num_samples:
                    break
                
                smiles = line.strip()
            
                if len(smiles) > self.max_atoms * 2: # skip very long strings without parsing them
                    continue
                
                # Checking actual atom count with rdkit
                mol = Chem.MolFromSmiles(smiles)
                if mol is not None and mol.GetNumAtoms() <= self.max_atoms:
                    self.smiles_list.append(smiles)

        print(f"Loaded {len(self.smiles_list)} molecules after filtering (max atoms = {self.max_atoms}).")

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        
        graph_data = smiles_to_graph_data(smiles)
        if graph_data is None:
            return None

        smiles_tokens = self.tokenizer(
            smiles,
            padding=False, 
            truncation=True,
            max_length=256,
            return_tensors='pt'
        )
        smiles_tokens = {key: val.squeeze(0) for key, val in smiles_tokens.items()}
        
        return graph_data, smiles_tokens

class CustomCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        batch = [item for item in batch if item is not None]
        if not batch:
            return None, None

        graphs, smiles_tokens_list = zip(*batch)
        graph_batch = Batch.from_data_list(graphs)

        smiles_padded = self.tokenizer.pad(
            {'input_ids': [s['input_ids'] for s in smiles_tokens_list],
             'attention_mask': [s['attention_mask'] for s in smiles_tokens_list]},
            return_tensors='pt',
            padding='longest'
        )
        
        return graph_batch, smiles_padded

## 5. Model Architecture
This section defines the three core PyTorch `nn.Module` classes: `GraphEncoder`, `SmilesEncoder`, and the main `GRASPModel`.
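
As background for the graph encoder: a GIN layer updates each node as `MLP((1 + eps) * h_i + sum of neighbour features)`, and `global_add_pool` then sums the node vectors of each molecule into one graph vector. The sketch below reproduces only the aggregation and pooling steps in plain Python, with the MLP, batch normalization, and dropout left out; the function names and the toy graph are illustrative.

```python
def gin_aggregate(node_features, edge_index, eps=0.0):
    """Compute (1 + eps) * h_i + sum of neighbour features for every node i.

    `edge_index` is a pair of lists (sources, targets). The MLP that GINConv
    applies to this sum is omitted in this sketch.
    """
    dim = len(node_features[0])
    out = [[(1.0 + eps) * v for v in h] for h in node_features]
    sources, targets = edge_index
    for src, dst in zip(sources, targets):
        for k in range(dim):
            out[dst][k] += node_features[src][k]
    return out

def add_pool(node_features, batch):
    """Sum node features per graph; batch[i] is the graph id of node i."""
    num_graphs = max(batch) + 1
    dim = len(node_features[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for h, graph_id in zip(node_features, batch):
        for k in range(dim):
            pooled[graph_id][k] += h[k]
    return pooled

# Toy graph: a 3-atom chain 0-1-2 with both directions of each bond listed.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = ([0, 1, 1, 2], [1, 0, 2, 1])
h = gin_aggregate(x, edges)
print(h)                       # [[1.0, 1.0], [2.0, 2.0], [1.0, 2.0]]
print(add_pool(h, [0, 0, 0]))  # [[4.0, 5.0]]
```

In the notebook, `train_eps=True` makes `eps` a learnable parameter, and sum pooling keeps the graph vector sensitive to molecule size, unlike mean pooling.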

In [14]:
class GraphEncoder(nn.Module):
    def __init__(self, num_node_features, embedding_dim, num_layers, dropout):
        super(GraphEncoder, self).__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = dropout
        self.num_layers = num_layers

        for i in range(num_layers):
            in_dim = int(num_node_features) if i == 0 else int(embedding_dim)
            hidden_dim = int(embedding_dim)

            mlp = nn.Sequential(
                nn.Linear(in_dim, 2 * hidden_dim),
                nn.ReLU(),
                nn.Linear(2 * hidden_dim, hidden_dim)
            )
            conv = GINConv(mlp, train_eps=True)
            self.convs.append(conv)
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    def forward(self, x, edge_index, batch):
        h = x
        for i in range(self.num_layers):
            h = self.convs[i](h, edge_index)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        
        h_graph = global_add_pool(h, batch)
        return h_graph

class SmilesEncoder(nn.Module):
    def __init__(self, model_name='seyonec/ChemBERTa-zinc-base-v1', dropout=0.1):
        super(SmilesEncoder, self).__init__()
        config = AutoConfig.from_pretrained(model_name)
        self.transformer = AutoModel.from_pretrained(model_name, config=config)
        self.smiles_embedding_dim = self.transformer.config.hidden_size

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :] 

class GRASPModel(nn.Module):
    def __init__(self, graph_emb_dim, graph_layers, projection_dim, dropout=0.1):
        super(GRASPModel, self).__init__()
        
        self.graph_encoder = GraphEncoder(
            num_node_features=get_num_node_features(),
            embedding_dim=graph_emb_dim,
            num_layers=graph_layers,
            dropout=dropout
        )
        
        self.smiles_encoder = SmilesEncoder(dropout=dropout)
        
        self.graph_projection = nn.Sequential(
            nn.Linear(graph_emb_dim, graph_emb_dim),
            nn.ReLU(),
            nn.Linear(graph_emb_dim, projection_dim)
        )
        
        self.smiles_projection = nn.Sequential(
            nn.Linear(self.smiles_encoder.smiles_embedding_dim, self.smiles_encoder.smiles_embedding_dim),
            nn.ReLU(),
            nn.Linear(self.smiles_encoder.smiles_embedding_dim, projection_dim)
        )

    def forward(self, graph_batch, smiles_batch):
        graph_embeddings = self.graph_encoder(
            x=graph_batch.x, edge_index=graph_batch.edge_index, batch=graph_batch.batch
        )
        smiles_embeddings = self.smiles_encoder(
            input_ids=smiles_batch['input_ids'], attention_mask=smiles_batch['attention_mask']
        )
        
        graph_proj = self.graph_projection(graph_embeddings)
        smiles_proj = self.smiles_projection(smiles_embeddings)
        
        return F.normalize(graph_proj, p=2, dim=1, eps=1e-8), F.normalize(smiles_proj, p=2, dim=1, eps=1e-8)

## 6. Pre-training Script
This section defines the InfoNCE loss function and the main training function that ties everything together.
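
For intuition about the loss values reported during training, the symmetric InfoNCE objective can be reproduced on a small similarity matrix in plain Python. This is a sketch of the same computation as the `InfoNCELoss` class in the cell below (cross-entropy over rows and columns with the diagonal as targets); the helper names and the two toy matrices are illustrative.

```python
import math

def cross_entropy_rows(logits, targets):
    """Mean cross-entropy where row i should pick column targets[i]."""
    total = 0.0
    for row, target in zip(logits, targets):
        peak = max(row)  # subtract the max for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_sum_exp - row[target]
    return total / len(logits)

def symmetric_info_nce(sim, temperature=0.07):
    """Average of the row-wise and column-wise losses for a square similarity matrix."""
    n = len(sim)
    logits = [[v / temperature for v in row] for row in sim]
    transposed = [list(col) for col in zip(*logits)]
    labels = list(range(n))
    return 0.5 * (cross_entropy_rows(logits, labels) + cross_entropy_rows(transposed, labels))

# Uninformative embeddings: every pair is equally similar, so the loss equals ln(batch size).
uniform = [[0.5] * 4 for _ in range(4)]
# Well-aligned embeddings: matching pairs have similarity 1, all others 0.
aligned = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]

print(round(symmetric_info_nce(uniform), 4))  # 1.3863, i.e. ln(4)
print(symmetric_info_nce(aligned) < 1e-4)     # True
```

The uniform case gives a useful reference point: with `BATCH_SIZE = 64`, a model that cannot tell pairs apart scores about ln(64) ≈ 4.16, so per-epoch averages far below that value mean the matching pairs are being separated from the in-batch negatives.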

In [15]:
class InfoNCELoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        labels = torch.arange(batch_size, device=z_i.device)
        sim_matrix = torch.matmul(z_i, z_j.T) / self.temperature
        loss_i_j = self.loss_fn(sim_matrix, labels)
        loss_j_i = self.loss_fn(sim_matrix.T, labels)
        return (loss_i_j + loss_j_i) / 2

def train_grasp():

    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("No CUDA GPU found, using CPU.")

    # Pipeline
    dataset = MoleculeDataset(SMILES_FILE_PATH, TOKENIZER_NAME, num_samples=NUM_SAMPLES)
    collator = CustomCollator(tokenizer=dataset.tokenizer)
    data_loader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator,
        num_workers=NUM_WORKERS, pin_memory=True if device.type == 'cuda' else False,
        persistent_workers=True if NUM_WORKERS > 0 else False
    )

    # Model, Loss, and Optimizer
    model = GRASPModel(
        projection_dim=PROJECTION_DIM, 
        graph_emb_dim=GRAPH_EMB_DIM, 
        graph_layers=GRAPH_LAYERS
    ).to(device)
    criterion = InfoNCELoss(temperature=TEMPERATURE)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(data_loader)*EPOCHS)

    # pre-training
    print("\nStarting pre-training")
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        
        progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        
        for batch_data in progress_bar:
            graph_batch, smiles_batch = batch_data
            if graph_batch is None: continue

            graph_batch = graph_batch.to(device)
            smiles_batch = {key: val.to(device) for key, val in smiles_batch.items()}
            
            optimizer.zero_grad()
            graph_proj, smiles_proj = model(graph_batch, smiles_batch)
            loss = criterion(graph_proj, smiles_proj)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(data_loader) if data_loader else 0
        print(f"\nEpoch {epoch+1}/{EPOCHS} - Average Loss: {avg_loss:.4f}")
        
        checkpoint_path = f"/kaggle/working/grasp_model_epoch_{epoch+1}.pt"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model checkpoint saved to {checkpoint_path}")

    print("\nPre-training complete.")

## 7. Run Training
This cell executes the main training function.

In [17]:
train_grasp()

Using CUDA GPU: Tesla T4
Loading and filtering SMILES strings from file...


Reading file: 500537it [01:30, 5527.35it/s]


Loaded 500000 molecules after filtering (max atoms = 512).

Starting pre-training...


Epoch 1/5:   0%|          | 0/7813 [00:00<?, ?it/s]
You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
2025-06-23 18:40:35.590377: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered


Epoch 1/5 - Average Loss: 0.0786
Model checkpoint saved to /kaggle/working/grasp_model_epoch_1.pt


  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 2/5 - Average Loss: 0.0139
Model checkpoint saved to /kaggle/working/grasp_model_epoch_2.pt


Epoch 3/5 - Average Loss: 0.0076
Model checkpoint saved to /kaggle/working/grasp_model_epoch_3.pt


Epoch 4/5 - Average Loss: 0.0050
Model checkpoint saved to /kaggle/working/grasp_model_epoch_4.pt


Epoch 5/5 - Average Loss: 0.0039
Model checkpoint saved to /kaggle/working/grasp_model_epoch_5.pt

Pre-training complete.


## 8. Qualitative Evaluation
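After training, this section checks qualitatively whether the two encoders have been aligned. It loads the final checkpoint, generates graph and SMILES embeddings for a few test molecules, and computes their pairwise cosine similarities. If alignment worked, each diagonal entry of the similarity matrix (a molecule's graph against its own SMILES) should be clearly higher than the off-diagonal entries in its row and column.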
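
The diagonal check can also be expressed as two numbers: the fraction of rows whose largest entry lies on the diagonal, and the average margin between the diagonal entry and the best competing entry in the same row. The sketch below is plain Python with illustrative function names; the 3x3 matrix is the ethanol, benzene, and aspirin block copied from the output printed at the end of this section.

```python
def diagonal_retrieval_accuracy(sim):
    """Fraction of rows whose largest entry sits on the diagonal."""
    hits = sum(1 for i, row in enumerate(sim) if row.index(max(row)) == i)
    return hits / len(sim)

def mean_diagonal_margin(sim):
    """Average gap between each diagonal entry and the best off-diagonal entry in its row."""
    gaps = []
    for i, row in enumerate(sim):
        best_other = max(v for j, v in enumerate(row) if j != i)
        gaps.append(row[i] - best_other)
    return sum(gaps) / len(gaps)

# Rows are graph embeddings, columns are SMILES embeddings
# (ethanol, benzene, aspirin), as printed in the notebook output.
block = [
    [0.9609, 0.3435, -0.0449],
    [0.3822, 0.8747, -0.0821],
    [-0.0338, -0.0404, 0.9520],
]

print(diagonal_retrieval_accuracy(block))     # 1.0
print(round(mean_diagonal_margin(block), 4))  # 0.6986
```

On a handful of molecules this is only a sanity check; the same two functions apply unchanged to a larger held-out set of embeddings.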

In [19]:
def evaluate_model():
    print("\n--- Starting Post-Pretraining Qualitative Evaluation ---")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Evaluation using device: {device}")
    
    final_checkpoint_path = f"/kaggle/working/grasp_model_epoch_{EPOCHS}.pt"
    if not os.path.exists(final_checkpoint_path):
        print(f"Error: Checkpoint file not found at {final_checkpoint_path}. Cannot evaluate.")
        return
        
    model = GRASPModel(
        projection_dim=PROJECTION_DIM, 
        graph_emb_dim=GRAPH_EMB_DIM, 
        graph_layers=GRAPH_LAYERS
    ).to(device)
    model.load_state_dict(torch.load(final_checkpoint_path, map_location=device))
    model.eval()
    print(f"Successfully loaded model from {final_checkpoint_path}")
    
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

    # Test Cases
    test_smiles = [
        "CCO",                      # Ethanol
        "c1ccccc1",                 # Benzene
        "CC(=O)Oc1ccccc1C(=O)O", # Aspirin
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # Caffeine
        "C"                         # Methane
    ]

    all_graph_embeddings = []
    all_smiles_embeddings = []
    valid_smiles_for_display = []

    print("\nGenerating embeddings for test molecules...")
    with torch.no_grad():
        for smiles in test_smiles:
            graph_data = smiles_to_graph_data(smiles)
            if graph_data is None:
                print(f"Warning: Could not process SMILES: {smiles}")
                continue
                   
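            # Edge case: a molecule without bonds (e.g. methane, "C") can yield a
            # 1-D empty edge_index; replace it with an explicit empty (2, 0) tensor.
            if graph_data.edge_index.dim() == 1:
                graph_data.edge_index = torch.empty((2, 0), dtype=torch.long)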
                    
            smiles_tokens = tokenizer(smiles, return_tensors='pt', padding=True)
            
            graph_batch = Batch.from_data_list([graph_data]).to(device)
            smiles_batch = {k: v.to(device) for k, v in smiles_tokens.items()}
            
            graph_proj, smiles_proj = model(graph_batch, smiles_batch)
            
            all_graph_embeddings.append(graph_proj)
            all_smiles_embeddings.append(smiles_proj)
            valid_smiles_for_display.append(smiles) # Add to list only if successful
    
    if not all_graph_embeddings:
        print("No embeddings were generated. Cannot create similarity matrix.")
        return
        
    graph_embeddings_tensor = torch.cat(all_graph_embeddings, dim=0)
    smiles_embeddings_tensor = torch.cat(all_smiles_embeddings, dim=0)

    # Similarity Matrix 
    print("\n--- Cosine Similarity Matrix (Graph vs. SMILES) ---")
    similarity_matrix = torch.matmul(graph_embeddings_tensor, smiles_embeddings_tensor.T).cpu().numpy()

    df = pd.DataFrame(similarity_matrix, index=valid_smiles_for_display, columns=valid_smiles_for_display)
    print(df.round(4))
    
    print("\n--- Key Observations ---")
    print("High values on the diagonal (positive pairs), low values off-diagonal.")
    print("This indicates the model successfully learned to align representations.")


evaluate_model()


--- Starting Post-Pretraining Qualitative Evaluation ---
Evaluation using device: cuda
Successfully loaded model from /kaggle/working/grasp_model_epoch_5.pt

Generating embeddings for test molecules...

--- Cosine Similarity Matrix (Graph vs. SMILES) ---
                                 CCO  c1ccccc1  CC(=O)Oc1ccccc1C(=O)O  \
CCO                           0.9609    0.3435                -0.0449   
c1ccccc1                      0.3822    0.8747                -0.0821   
CC(=O)Oc1ccccc1C(=O)O        -0.0338   -0.0404                 0.9520   
CN1C=NC2=C1C(=O)N(C(=O)N2C)C  0.0961    0.2039                 0.0124   
C                             0.4803    0.3894                 0.2627   

                              CN1C=NC2=C1C(=O)N(C(=O)N2C)C       C  
CCO                                                 0.2041  0.5417  
c1ccccc1                                            0.3024  0.2932  
CC(=O)Oc1ccccc1C(=O)O                              -0.2467  0.2543  
CN1C=NC2=C1C(=O)N(C(=O)N2C)C 

In [2]:
# RUN THIS CELL FIRST!
!pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-cluster torch-spline-conv -y --quiet

# Install a compatible CUDA version of PyTorch
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121 --quiet

# Install core libraries
!pip install "transformers==4.36.2" "accelerate==0.25.0" "deepchem" --quiet

# Install PyTorch Geometric and its dependencies
!pip install torch_geometric rdkit-pypi pandas tqdm --quiet
!pip install pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html --quiet

print("\n✅ All required libraries have been installed successfully!")


In [2]:
# ===================================================================
# CELL 1: COMPLETE & ROBUST INSTALLATION
# ===================================================================
# This command installs all necessary libraries at once.
# It lets pip resolve the dependencies for the latest Kaggle environment.
!pip install deepchem rdkit-pypi torch_geometric --quiet

print("✅ Installation complete. All required libraries (DeepChem, RDKit, PyTorch Geometric) are installed.")

✅ Installation complete. All required libraries (DeepChem, RDKit, PyTorch Geometric) are installed.


In [3]:
# ===================================================================
# CELL 2: ALL IMPORTS AND CONFIGURATION
# ===================================================================

# --- All Imports ---
import os, gc, numpy as np, pandas as pd
from tqdm import tqdm
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv, global_add_pool
from transformers import AutoModel, AutoConfig
from rdkit import Chem, rdBase
import deepchem as dc
from sklearn.metrics import roc_auc_score, mean_squared_error

# --- Environment Setup ---
rdBase.DisableLog('rdApp.warning'); rdBase.DisableLog('rdApp.error')
print(f"✅ Setup Complete. Using DeepChem version: {dc.__version__}")

# --- Configuration ---
PROJECTION_DIM, GRAPH_EMB_DIM, GRAPH_LAYERS, NUM_WORKERS = 128, 128, 4, 2
PRETRAINED_CHECKPOINT_PATH = "/kaggle/input/baseline/pytorch/default/1/Saved-Model-and-Encoders(Kaggle)/grasp_model_epoch_5.pt"
FT_WARMUP_EPOCHS, FT_MAIN_EPOCHS, FT_BATCH_SIZE = 5, 25, 32
ENCODER_LEARNING_RATE, HEAD_LEARNING_RATE = 1e-5, 1e-4

if not os.path.exists(PRETRAINED_CHECKPOINT_PATH): raise FileNotFoundError(f"Checkpoint not found at {PRETRAINED_CHECKPOINT_PATH}")
print(f"Found checkpoint: {PRETRAINED_CHECKPOINT_PATH}")

2025-07-05 08:09:04.071839: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751702944.251847      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751702944.303117      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


✅ Setup Complete. Using DeepChem version: 2.8.0
Found checkpoint: /kaggle/input/baseline/pytorch/default/1/Saved-Model-and-Encoders(Kaggle)/grasp_model_epoch_5.pt


In [16]:
# ===================================================================
# CELL 3: RE-DEFINE PRE-TRAINING BUILDING BLOCKS (FINAL CORRECTION)
# ===================================================================

# --- SMILES to Graph Utility Functions ---
# (This part is correct and remains the same)
ATOM_FEATURE_MAP = {
    'atomic_num': list(range(1, 119)),
    'degree': list(range(6)),
    'formal_charge': list(range(-2, 3)),
    'hybridization': [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'is_aromatic': [0, 1],
    'is_in_ring': [0, 1]
}

def one_hot_encode(v, c):
    e = [0] * (len(c) + 1)
    e[c.index(v) if v in c else -1] = 1
    return e

def get_atom_features(a):
    f = []
    f += one_hot_encode(a.GetAtomicNum(), ATOM_FEATURE_MAP['atomic_num'])
    f += one_hot_encode(a.GetDegree(), ATOM_FEATURE_MAP['degree'])
    f += one_hot_encode(a.GetFormalCharge(), ATOM_FEATURE_MAP['formal_charge'])
    f += one_hot_encode(a.GetHybridization(), ATOM_FEATURE_MAP['hybridization'])
    f += one_hot_encode(int(a.GetIsAromatic()), ATOM_FEATURE_MAP['is_aromatic'])
    f += one_hot_encode(int(a.IsInRing()), ATOM_FEATURE_MAP['is_in_ring'])
    return torch.tensor(f, dtype=torch.float)

def get_num_node_features():
    return sum(len(c) + 1 for c in ATOM_FEATURE_MAP.values())
def smiles_to_graph_data(s):
    try:
        m = Chem.MolFromSmiles(s)
        if m is None: return None
        af = [get_atom_features(a) for a in m.GetAtoms()]
        if not af: return None
        x = torch.stack(af)
        ei = []
        for b in m.GetBonds():
            i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
            ei.extend([(i, j), (j, i)])  # store both directions of each bond
        if ei:
            edge_index = torch.tensor(ei, dtype=torch.long).t().contiguous()
        else:
            # Molecules without bonds need an explicit empty (2, 0) edge index
            edge_index = torch.empty((2, 0), dtype=torch.long)
        return Data(x=x, edge_index=edge_index)
    except Exception: return None

# --- Core GRASP Model Architecture (THE ORIGINAL, CORRECT VERSION) ---
class GraphEncoder(nn.Module):
    def __init__(self, num_node_features, embedding_dim, num_layers, dropout):
        super(GraphEncoder, self).__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.dropout = dropout
        self.num_layers = num_layers  # The variable is named num_layers here
        for i in range(num_layers):
            in_dim = num_node_features if i == 0 else embedding_dim
            mlp = nn.Sequential(
                nn.Linear(in_dim, 2 * embedding_dim),
                nn.ReLU(),
                nn.Linear(2 * embedding_dim, embedding_dim)
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(embedding_dim))
            
    def forward(self, x, edge_index, batch):
        h = x
        # Loop over self.num_layers, the attribute set in __init__.
        for i in range(self.num_layers):
            h = self.convs[i](h, edge_index)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        return global_add_pool(h, batch)

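class GRASPModel(nn.Module):
    def __init__(self, g, l, p, d):
        # g: graph embedding dim, l: number of GIN layers, d: dropout.
        # p is accepted but not used in this cell.
        super(GRASPModel, self).__init__()
        self.graph_encoder = GraphEncoder(get_num_node_features(), g, l, d)
    def forward(self, gb, sb):
        # Placeholder: the pre-training forward pass is not defined in this cell.
        pass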

print("✅ Pre-training building blocks defined with FINAL corrected GraphEncoder.")

✅ Pre-training building blocks defined with FINAL corrected GraphEncoder.
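
As a quick check on the featurization above, the per-atom feature length can be worked out without RDKit. The sketch below mirrors `one_hot_encode` and `get_num_node_features`; the names `DEMO_FEATURE_MAP`, `one_hot_with_unknown`, and `num_node_features` are hypothetical, and the plain strings standing in for the RDKit hybridization enums are an assumption made only to keep the example dependency-free.

```python
# Stand-in for ATOM_FEATURE_MAP: same category counts as the notebook, but
# plain strings replace the RDKit hybridization enums (assumption for this sketch).
DEMO_FEATURE_MAP = {
    'atomic_num': list(range(1, 119)),
    'degree': list(range(6)),
    'formal_charge': list(range(-2, 3)),
    'hybridization': ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'UNSPECIFIED'],
    'is_aromatic': [0, 1],
    'is_in_ring': [0, 1],
}

def one_hot_with_unknown(value, choices):
    """One slot per known choice plus a final slot for anything unseen."""
    encoding = [0] * (len(choices) + 1)
    encoding[choices.index(value) if value in choices else -1] = 1
    return encoding

def num_node_features(feature_map):
    """Total feature length: each block has len(choices) + 1 slots."""
    return sum(len(choices) + 1 for choices in feature_map.values())

carbon = one_hot_with_unknown(6, DEMO_FEATURE_MAP['atomic_num'])
unseen = one_hot_with_unknown(7, DEMO_FEATURE_MAP['degree'])  # degree 7 is not listed

print(num_node_features(DEMO_FEATURE_MAP))  # 145
print(carbon.index(1), unseen.index(1))     # 5 6
```

With these category lists, each atom is described by 145 features, which is the input width the first GIN layer receives.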


In [27]:
# ===================================================================
# CELL 4: (FINAL CORRECTED) MANUAL DATA LOADING & DOWNSTREAM PIPELINE
# ===================================================================
import numpy as np
import pandas as pd
import requests
import deepchem as dc  # used by create_dc_dataset below
from io import StringIO, BytesIO  # BytesIO is needed for compressed data

# --- Manually download a MoleculeNet CSV (handles gzip compression) ---
def download_moleculenet_csv(url):
    """
    Downloads a CSV file from a URL and returns a pandas DataFrame.
    Handles .gz compressed files such as Tox21. Returns None if the request fails.
    """
    try:
        response = requests.get(url)
        response.raise_for_status()
        
        # Check if the URL suggests compression
        if url.endswith('.gz'):
            # For compressed data, we use BytesIO and tell pandas about the compression
            return pd.read_csv(BytesIO(response.content), compression='gzip')
        else:
            # For plain text CSVs, we use StringIO
            return pd.read_csv(StringIO(response.text))
            
    except requests.exceptions.RequestException as e:
        print(f"Error downloading data from {url}: {e}")
        return None

# --- Build DeepChem Dataset objects from a DataFrame ---
def create_dc_dataset(dataframe, smiles_col, label_cols):
    """Converts a pandas DataFrame into a deepchem.data.Dataset object."""
    smiles = dataframe[smiles_col].tolist()
    labels = dataframe[label_cols].values
    return dc.data.NumpyDataset(X=np.zeros(len(smiles)), y=labels, ids=smiles)


# --- Load a MoleculeNet dataset via manual download instead of dc.molnet ---
def load_moleculenet_dataset(name):
    print(f"Bypassing dc.molnet. Manually loading '{name}' dataset...")
    
    if name == 'bbbp':
        url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv"
        tasks, smiles_col = ['p_np'], 'smiles'

    elif name == 'esol':
        url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv"
        tasks, smiles_col = ['measured log solubility in mols per litre'], 'smiles'

    elif name == 'tox21':
        url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz"
        tasks, smiles_col = None, 'smiles'  # task columns are read from the file below

    else:
        raise ValueError(f"Dataset {name} not supported for manual loading.")

    df = download_moleculenet_csv(url)
    if df is None:
        raise RuntimeError(f"Could not download the '{name}' dataset.")
    if tasks is None:
        tasks = [col for col in df.columns if col not in ['mol_id', 'smiles', 'Set']]
    label_cols = tasks

    # Random 80/10/10 split. iloc slicing gives the same partition as the
    # original np.split call without passing a DataFrame to NumPy.
    df = df.sample(frac=1, random_state=42)
    n_train, n_valid = int(.8 * len(df)), int(.9 * len(df))
    train_df, valid_df, test_df = df.iloc[:n_train], df.iloc[n_train:n_valid], df.iloc[n_valid:]

    train_dc = create_dc_dataset(train_df, smiles_col, label_cols)
    valid_dc = create_dc_dataset(valid_df, smiles_col, label_cols)
    test_dc = create_dc_dataset(test_df, smiles_col, label_cols)

    print(f"Loaded {name}: Train {len(train_dc)}, Valid {len(valid_dc)}, Test {len(test_dc)}")
    return tasks, train_dc, valid_dc, test_dc

# Dataset wrapper and collator for downstream fine-tuning.
class MoleculeNetDataset(Dataset):
    def __init__(self, d): self.s, self.l = d.ids, d.y
    def __len__(self): return len(self.s)
    def __getitem__(self, i):
        # Returns None for SMILES that cannot be converted; the collator drops these.
        g = smiles_to_graph_data(self.s[i])
        return (g, torch.tensor(self.l[i], dtype=torch.float)) if g is not None else None

class DownstreamCollator:
    def __call__(self, b):
        b = [i for i in b if i is not None]
        if not b: return None, None
        return Batch.from_data_list([i[0] for i in b]), torch.stack([i[1] for i in b])

class DownstreamModel(nn.Module):
    def __init__(self, p, nt):
        # p: path to the pre-trained checkpoint, nt: number of prediction tasks.
        super().__init__()
        sd = torch.load(p, map_location=torch.device('cpu'))
        self.graph_encoder = GraphEncoder(get_num_node_features(), GRAPH_EMB_DIM, GRAPH_LAYERS, 0.1)
        # Keep only the graph-encoder weights and strip their 'graph_encoder.' prefix.
        esd = {k.replace('graph_encoder.',''): v for k,v in sd.items() if k.startswith('graph_encoder.')}
        self.graph_encoder.load_state_dict(esd)
        self.prediction_head = nn.Sequential(nn.Linear(GRAPH_EMB_DIM, GRAPH_EMB_DIM//2), nn.ReLU(), nn.Dropout(0.2), nn.Linear(GRAPH_EMB_DIM//2, nt))
    def freeze_encoder(self):
        for q in self.graph_encoder.parameters(): q.requires_grad_(False)
        print("Encoder frozen.")
    def unfreeze_encoder(self):
        for q in self.graph_encoder.parameters(): q.requires_grad_(True)
        print("Encoder unfrozen.")
    def forward(self, gb): return self.prediction_head(self.graph_encoder(gb.x, gb.edge_index, gb.batch))

print("✅ Downstream data utilities (with GZIP support) and model defined.")

✅ Downstream data utilities (with GZIP support) and model defined.
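
The random 80/10/10 split used by `load_moleculenet_dataset` can be sketched with the standard library alone. This is an illustration, not the notebook's code: `split_indices` is a hypothetical helper, it shuffles with `random.Random` rather than `DataFrame.sample`, so the row order differs from the notebook's, but the cut points `int(.8 * n)` and `int(.9 * n)` are the same. The row counts are the sums of the split sizes printed in the run logs.

```python
import random

def split_indices(n, seed=42):
    """Shuffle range(n) and cut it at the 80% and 90% marks."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    train_end, valid_end = int(.8 * n), int(.9 * n)
    return order[:train_end], order[train_end:valid_end], order[valid_end:]

# Row counts inferred from the logged split sizes (train + valid + test).
row_counts = {'bbbp': 2050, 'esol': 1128, 'tox21': 7831}

split_sizes = {}
for name, n in row_counts.items():
    train, valid, test = split_indices(n)
    split_sizes[name] = (len(train), len(valid), len(test))
    print(name, split_sizes[name])
```

Because `int()` truncates, any leftover rows end up in the test split, which is why Tox21 has 784 test molecules but 783 validation molecules.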


In [28]:
# ===================================================================
# CELL 5: MASTER FINE-TUNING SCRIPT
# ===================================================================
def run_finetuning_experiment_v2(dataset_name):
    """Two-stage fine-tuning on one dataset: warm up the prediction head with the
    encoder frozen, then fine-tune the whole model and return the test metric of
    the checkpoint with the best validation score."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n{'='*60}\n  STARTING EXPERIMENT FOR: {dataset_name.upper()} on device: {device}\n{'='*60}")

    tasks, train_dc, valid_dc, test_dc = load_moleculenet_dataset(dataset_name)
    num_tasks, task_type = len(tasks), 'classification' if dataset_name != 'esol' else 'regression'

    collator = DownstreamCollator()
    train_loader = DataLoader(MoleculeNetDataset(train_dc), FT_BATCH_SIZE, shuffle=True, collate_fn=collator, num_workers=NUM_WORKERS)
    valid_loader = DataLoader(MoleculeNetDataset(valid_dc), FT_BATCH_SIZE, shuffle=False, collate_fn=collator, num_workers=NUM_WORKERS)
    test_loader = DataLoader(MoleculeNetDataset(test_dc), FT_BATCH_SIZE, shuffle=False, collate_fn=collator, num_workers=NUM_WORKERS)

    model = DownstreamModel(PRETRAINED_CHECKPOINT_PATH, num_tasks).to(device)
    criterion = nn.BCEWithLogitsLoss(reduction='none') if task_type == 'classification' else nn.MSELoss()

    print(f"\n--- [STAGE 1] Head Warm-up ({FT_WARMUP_EPOCHS} epochs) ---")
    model.freeze_encoder()
    optimizer = optim.AdamW(model.prediction_head.parameters(), lr=HEAD_LEARNING_RATE)
    for epoch in range(FT_WARMUP_EPOCHS):
        model.train()
        for graph_batch, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{FT_WARMUP_EPOCHS} [Warm-up]"):
            if graph_batch is None: continue
            graph_batch, labels = graph_batch.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(graph_batch)
            mask = ~torch.isnan(labels)
            if task_type == 'classification': loss = criterion(outputs[mask], labels[mask]).mean() if mask.any() else torch.tensor(0.0)
            else: loss = criterion(outputs, labels)
            if loss.requires_grad: loss.backward(); optimizer.step()

    print(f"\n--- [STAGE 2] Full Fine-tuning ({FT_MAIN_EPOCHS} epochs) ---")
    model.unfreeze_encoder()
    optimizer = optim.AdamW([
        {'params': model.graph_encoder.parameters(), 'lr': ENCODER_LEARNING_RATE},
        {'params': model.prediction_head.parameters(), 'lr': HEAD_LEARNING_RATE}
    ])
    best_valid_metric, best_model_path = (-1 if task_type == 'classification' else float('inf')), f"/kaggle/working/best_{dataset_name}.pt"

    for epoch in range(FT_MAIN_EPOCHS):
        model.train(); total_loss = 0
        for graph_batch, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{FT_MAIN_EPOCHS} [Fine-tune]"):
            if graph_batch is None: continue
            graph_batch, labels = graph_batch.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(graph_batch)
            mask = ~torch.isnan(labels)
            if task_type == 'classification': loss = criterion(outputs[mask], labels[mask]).mean() if mask.any() else torch.tensor(0.0)
            else: loss = criterion(outputs, labels)
            if loss.requires_grad: loss.backward(); optimizer.step()
            total_loss += loss.item()
        
        model.eval(); all_preds, all_trues = [], []
        with torch.no_grad():
            for graph_batch, labels in valid_loader:
                if graph_batch is None: continue
                outputs = model(graph_batch.to(device))
                all_preds.append((torch.sigmoid(outputs) if task_type == 'classification' else outputs).cpu().numpy())
                all_trues.append(labels.cpu().numpy())
        all_preds, all_trues = np.concatenate(all_preds), np.concatenate(all_trues)

        if task_type == 'classification':
            aucs = [roc_auc_score(all_trues[:,i][~np.isnan(all_trues[:,i])], all_preds[:,i][~np.isnan(all_trues[:,i])]) for i in range(num_tasks) if len(np.unique(all_trues[:,i][~np.isnan(all_trues[:,i])])) > 1]
            valid_metric = np.mean(aucs) if aucs else 0.0
            print(f"Epoch {epoch+1} | Train Loss: {total_loss/len(train_loader):.4f} | Valid AUC: {valid_metric:.4f}")
            if valid_metric > best_valid_metric:
                best_valid_metric = valid_metric
                torch.save(model.state_dict(), best_model_path)
                print("  -> New best model saved")
        else:
            valid_metric = np.sqrt(mean_squared_error(all_trues, all_preds))
            print(f"Epoch {epoch+1} | Train Loss: {total_loss/len(train_loader):.4f} | Valid RMSE: {valid_metric:.4f}")
            if valid_metric < best_valid_metric:
                best_valid_metric = valid_metric
                torch.save(model.state_dict(), best_model_path)
                print("  -> New best model saved")

    print(f"\n--- Evaluating best model on {dataset_name} test set ---")
    model.load_state_dict(torch.load(best_model_path)); model.eval()
    all_preds_test, all_trues_test = [], []
    with torch.no_grad():
        for graph_batch, labels in tqdm(test_loader, desc="[Final Test]"):
            if graph_batch is None: continue
            outputs = model(graph_batch.to(device))
            all_preds_test.append((torch.sigmoid(outputs) if task_type == 'classification' else outputs).cpu().numpy())
            all_trues_test.append(labels.cpu().numpy())
    all_preds_test, all_trues_test = np.concatenate(all_preds_test), np.concatenate(all_trues_test)
    
    if task_type == 'classification':
        test_aucs = [roc_auc_score(all_trues_test[:,i][~np.isnan(all_trues_test[:,i])], all_preds_test[:,i][~np.isnan(all_trues_test[:,i])]) for i in range(num_tasks) if len(np.unique(all_trues_test[:,i][~np.isnan(all_trues_test[:,i])])) > 1]
        test_metric = np.mean(test_aucs) if test_aucs else 0.0
        print(f"\nFINAL TEST METRIC for {dataset_name}: ROC-AUC = {test_metric:.4f}")
    else:
        test_metric = np.sqrt(mean_squared_error(all_trues_test, all_preds_test))
        print(f"\nFINAL TEST METRIC for {dataset_name}: RMSE = {test_metric:.4f}")
    
    del model, train_loader, valid_loader, test_loader; gc.collect(); torch.cuda.empty_cache()
    return test_metric

print("✅ Master fine-tuning function defined.")

✅ Master fine-tuning function defined.
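
For multi-task classification sets such as Tox21, many labels are missing, so the training loop masks NaN entries before averaging the per-element loss. The sketch below restates that idea in plain Python as a stand-in for `nn.BCEWithLogitsLoss(reduction='none')` followed by the mask; `bce_with_logits` and `masked_mean_bce` are hypothetical names, and returning `None` for a fully unlabeled batch is an assumption of this sketch (the notebook instead substitutes a zero tensor and skips the backward pass).

```python
import math

def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy computed from a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def masked_mean_bce(logits, labels):
    """Average the loss over entries whose label is not NaN; None if none are labeled."""
    losses = [
        bce_with_logits(x, y)
        for row_x, row_y in zip(logits, labels)
        for x, y in zip(row_x, row_y)
        if not math.isnan(y)
    ]
    return sum(losses) / len(losses) if losses else None

# Two molecules, two tasks; the second task is unlabeled for both.
logits = [[0.0, 2.0], [0.0, -1.0]]
labels = [[1.0, float('nan')], [0.0, float('nan')]]

loss = masked_mean_bce(logits, labels)
print(round(loss, 4))  # 0.6931
```

Only the two labeled entries contribute, and a logit of 0 gives a loss of ln 2 for either class, so the masked mean is ln 2 regardless of the logits in the unlabeled column.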


In [20]:
import deepchem as dc

# List all functions in dc.molnet that seem to be loaders
loader_functions = [func for func in dir(dc.molnet) if func.startswith('load_')]

print("Available loader functions in your deepchem.molnet version:")
print(loader_functions)

Available loader functions in your deepchem.molnet version:
['load_Platinum_Adsorption', 'load_bace_classification', 'load_bace_regression', 'load_bandgap', 'load_bbbc001', 'load_bbbc002', 'load_bbbc003', 'load_bbbc004', 'load_bbbc005', 'load_bbbp', 'load_cell_counting', 'load_chembl', 'load_chembl25', 'load_clearance', 'load_clintox', 'load_delaney', 'load_factors', 'load_freesolv', 'load_function', 'load_hiv', 'load_hopv', 'load_hppb', 'load_kaggle', 'load_kinase', 'load_lipo', 'load_mp_formation_energy', 'load_mp_metallicity', 'load_muv', 'load_nci', 'load_pcba', 'load_pdbbind', 'load_perovskite', 'load_ppb', 'load_qm7', 'load_qm8', 'load_qm9', 'load_sampl', 'load_sider', 'load_sweet', 'load_thermosol', 'load_tox21', 'load_toxcast', 'load_uspto', 'load_uv', 'load_zinc15']


In [29]:
# ===================================================================
# CELL 6: EXECUTE ALL EXPERIMENTS AND SUMMARIZE
# ===================================================================
results = {
    'bbbp': {'metric': 'ROC-AUC', 'score': run_finetuning_experiment_v2('bbbp')},
    'esol': {'metric': 'RMSE', 'score': run_finetuning_experiment_v2('esol')},
    'tox21': {'metric': 'ROC-AUC', 'score': run_finetuning_experiment_v2('tox21')}
}

# --- Summarize Results ---
print("\n" + "="*50)
print("           GRASP FINAL BENCHMARK RESULTS")
print("="*50)
print(f"| {'Dataset':<8} | {'Metric':<10} | {'Score':<11} |")
print(f"|{'-'*10}|{'-'*12}|{'-'*13}|")
print(f"| {'BBBP':<8} | {results['bbbp']['metric']:<10} | {results['bbbp']['score']:<11.4f} |")
print(f"| {'Tox21':<8} | {results['tox21']['metric']:<10} | {results['tox21']['score']:<11.4f} |")
print(f"| {'ESOL':<8} | {results['esol']['metric']:<10} | {results['esol']['score']:<11.4f} |")
print("="*50)


  STARTING EXPERIMENT FOR: BBBP on device: cuda
Bypassing dc.molnet. Manually loading 'bbbp' dataset...




Loaded bbbp: Train 1640, Valid 205, Test 205

--- [STAGE 1] Head Warm-up (5 epochs) ---
Encoder frozen.


Epoch 1/5 [Warm-up]: 100%|██████████| 52/52 [00:01<00:00, 29.86it/s]
Epoch 2/5 [Warm-up]: 100%|██████████| 52/52 [00:01<00:00, 29.11it/s]
Epoch 3/5 [Warm-up]: 100%|██████████| 52/52 [00:01<00:00, 29.34it/s]
Epoch 4/5 [Warm-up]: 100%|██████████| 52/52 [00:01<00:00, 29.65it/s]
Epoch 5/5 [Warm-up]: 100%|██████████| 52/52 [00:01<00:00, 29.59it/s]



--- [STAGE 2] Full Fine-tuning (25 epochs) ---
Encoder unfrozen.


Epoch 1/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.88it/s]


Epoch 1 | Train Loss: 0.4964 | Valid AUC: 0.8298
  -> New best model saved


Epoch 2/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.75it/s]


Epoch 2 | Train Loss: 0.4806 | Valid AUC: 0.8372
  -> New best model saved


Epoch 3/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 27.43it/s]


Epoch 3 | Train Loss: 0.4526 | Valid AUC: 0.8405
  -> New best model saved


Epoch 4/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.04it/s]


Epoch 4 | Train Loss: 0.4241 | Valid AUC: 0.8473
  -> New best model saved


Epoch 5/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.68it/s]


Epoch 5 | Train Loss: 0.4011 | Valid AUC: 0.8400


Epoch 6/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.86it/s]


Epoch 6 | Train Loss: 0.4034 | Valid AUC: 0.8528
  -> New best model saved


Epoch 7/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 28.28it/s]


Epoch 7 | Train Loss: 0.3908 | Valid AUC: 0.8586
  -> New best model saved


Epoch 8/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.04it/s]


Epoch 8 | Train Loss: 0.3730 | Valid AUC: 0.8664
  -> New best model saved


Epoch 9/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.66it/s]


Epoch 9 | Train Loss: 0.3515 | Valid AUC: 0.8694
  -> New best model saved


Epoch 10/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.50it/s]


Epoch 10 | Train Loss: 0.3548 | Valid AUC: 0.8674


Epoch 11/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.00it/s]


Epoch 11 | Train Loss: 0.3527 | Valid AUC: 0.8842
  -> New best model saved


Epoch 12/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.25it/s]


Epoch 12 | Train Loss: 0.3279 | Valid AUC: 0.8817


Epoch 13/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.04it/s]


Epoch 13 | Train Loss: 0.3395 | Valid AUC: 0.8851
  -> New best model saved


Epoch 14/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.99it/s]


Epoch 14 | Train Loss: 0.3237 | Valid AUC: 0.8896
  -> New best model saved


Epoch 15/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.49it/s]


Epoch 15 | Train Loss: 0.3117 | Valid AUC: 0.8944
  -> New best model saved


Epoch 16/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.17it/s]


Epoch 16 | Train Loss: 0.3134 | Valid AUC: 0.8901


Epoch 17/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 28.22it/s]


Epoch 17 | Train Loss: 0.3025 | Valid AUC: 0.8912


Epoch 18/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.11it/s]


Epoch 18 | Train Loss: 0.3040 | Valid AUC: 0.9045
  -> New best model saved


Epoch 19/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 28.24it/s]


Epoch 19 | Train Loss: 0.2940 | Valid AUC: 0.9055
  -> New best model saved


Epoch 20/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.30it/s]


Epoch 20 | Train Loss: 0.2902 | Valid AUC: 0.8970


Epoch 21/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.30it/s]


Epoch 21 | Train Loss: 0.2882 | Valid AUC: 0.9077
  -> New best model saved


Epoch 22/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.06it/s]


Epoch 22 | Train Loss: 0.2802 | Valid AUC: 0.8979


Epoch 23/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 30.46it/s]


Epoch 23 | Train Loss: 0.2772 | Valid AUC: 0.9045


Epoch 24/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.84it/s]


Epoch 24 | Train Loss: 0.2879 | Valid AUC: 0.9110
  -> New best model saved


Epoch 25/25 [Fine-tune]: 100%|██████████| 52/52 [00:01<00:00, 29.80it/s]


Epoch 25 | Train Loss: 0.2784 | Valid AUC: 0.9128
  -> New best model saved

--- Evaluating best model on bbbp test set ---


[Final Test]: 100%|██████████| 7/7 [00:00<00:00, 22.84it/s]



FINAL TEST METRIC for bbbp: ROC-AUC = 0.9400

  STARTING EXPERIMENT FOR: ESOL on device: cuda
Bypassing dc.molnet. Manually loading 'esol' dataset...




Loaded esol: Train 902, Valid 113, Test 113

--- [STAGE 1] Head Warm-up (5 epochs) ---
Encoder frozen.


Epoch 1/5 [Warm-up]: 100%|██████████| 29/29 [00:00<00:00, 38.60it/s]
Epoch 2/5 [Warm-up]: 100%|██████████| 29/29 [00:00<00:00, 37.24it/s]
Epoch 3/5 [Warm-up]: 100%|██████████| 29/29 [00:00<00:00, 38.19it/s]
Epoch 4/5 [Warm-up]: 100%|██████████| 29/29 [00:00<00:00, 39.34it/s]
Epoch 5/5 [Warm-up]: 100%|██████████| 29/29 [00:00<00:00, 37.58it/s]



--- [STAGE 2] Full Fine-tuning (25 epochs) ---
Encoder unfrozen.


Epoch 1/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.25it/s]


Epoch 1 | Train Loss: 2.6189 | Valid RMSE: 1.1844
  -> New best model saved


Epoch 2/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 37.45it/s]


Epoch 2 | Train Loss: 2.2085 | Valid RMSE: 1.1189
  -> New best model saved


Epoch 3/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.96it/s]


Epoch 3 | Train Loss: 1.9776 | Valid RMSE: 1.0360
  -> New best model saved


Epoch 4/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 40.27it/s]


Epoch 4 | Train Loss: 1.6972 | Valid RMSE: 1.0461


Epoch 5/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.41it/s]


Epoch 5 | Train Loss: 1.7412 | Valid RMSE: 0.9693
  -> New best model saved


Epoch 6/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.14it/s]


Epoch 6 | Train Loss: 1.7779 | Valid RMSE: 0.9398
  -> New best model saved


Epoch 7/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 36.71it/s]


Epoch 7 | Train Loss: 1.5539 | Valid RMSE: 0.9446


Epoch 8/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 37.72it/s]


Epoch 8 | Train Loss: 1.6547 | Valid RMSE: 0.9318
  -> New best model saved


Epoch 9/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.29it/s]


Epoch 9 | Train Loss: 1.3516 | Valid RMSE: 0.9130
  -> New best model saved


Epoch 10/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.85it/s]


Epoch 10 | Train Loss: 1.3035 | Valid RMSE: 0.8868
  -> New best model saved


Epoch 11/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 37.40it/s]


Epoch 11 | Train Loss: 1.2321 | Valid RMSE: 0.8688
  -> New best model saved


Epoch 12/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.15it/s]


Epoch 12 | Train Loss: 1.2562 | Valid RMSE: 0.8602
  -> New best model saved


Epoch 13/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 32.48it/s]


Epoch 13 | Train Loss: 1.3935 | Valid RMSE: 0.8557
  -> New best model saved


Epoch 14/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 35.63it/s]


Epoch 14 | Train Loss: 1.3107 | Valid RMSE: 0.8553
  -> New best model saved


Epoch 15/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.70it/s]


Epoch 15 | Train Loss: 1.1215 | Valid RMSE: 0.8410
  -> New best model saved


Epoch 16/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 37.60it/s]


Epoch 16 | Train Loss: 1.0148 | Valid RMSE: 0.8436


Epoch 17/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 36.09it/s]


Epoch 17 | Train Loss: 1.1825 | Valid RMSE: 0.8096
  -> New best model saved


Epoch 18/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.88it/s]


Epoch 18 | Train Loss: 1.1373 | Valid RMSE: 0.8190


Epoch 19/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.16it/s]


Epoch 19 | Train Loss: 1.0838 | Valid RMSE: 0.7987
  -> New best model saved


Epoch 20/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.31it/s]


Epoch 20 | Train Loss: 1.0747 | Valid RMSE: 0.7985
  -> New best model saved


Epoch 21/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.01it/s]


Epoch 21 | Train Loss: 1.0390 | Valid RMSE: 0.7973
  -> New best model saved


Epoch 22/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.63it/s]


Epoch 22 | Train Loss: 0.9608 | Valid RMSE: 0.7931
  -> New best model saved


Epoch 23/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 39.38it/s]


Epoch 23 | Train Loss: 1.0152 | Valid RMSE: 0.7690
  -> New best model saved


Epoch 24/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 38.13it/s]


Epoch 24 | Train Loss: 0.9602 | Valid RMSE: 0.7635
  -> New best model saved


Epoch 25/25 [Fine-tune]: 100%|██████████| 29/29 [00:00<00:00, 37.09it/s]


Epoch 25 | Train Loss: 0.9365 | Valid RMSE: 0.7843

--- Evaluating best model on esol test set ---


[Final Test]: 100%|██████████| 4/4 [00:00<00:00, 20.22it/s]



FINAL TEST METRIC for esol: RMSE = 0.8986

  STARTING EXPERIMENT FOR: TOX21 on device: cuda
Bypassing dc.molnet. Manually loading 'tox21' dataset...




Loaded tox21: Train 6264, Valid 783, Test 784

--- [STAGE 1] Head Warm-up (5 epochs) ---
Encoder frozen.


Epoch 1/5 [Warm-up]: 100%|██████████| 196/196 [00:05<00:00, 36.02it/s]
Epoch 2/5 [Warm-up]: 100%|██████████| 196/196 [00:05<00:00, 36.65it/s]
Epoch 3/5 [Warm-up]: 100%|██████████| 196/196 [00:05<00:00, 37.13it/s]
Epoch 4/5 [Warm-up]: 100%|██████████| 196/196 [00:05<00:00, 35.15it/s]
Epoch 5/5 [Warm-up]: 100%|██████████| 196/196 [00:05<00:00, 37.67it/s]



--- [STAGE 2] Full Fine-tuning (25 epochs) ---
Encoder unfrozen.


Epoch 1/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.53it/s]


Epoch 1 | Train Loss: 0.2618 | Valid AUC: 0.7105
  -> New best model saved


Epoch 2/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 38.17it/s]


Epoch 2 | Train Loss: 0.2466 | Valid AUC: 0.7319
  -> New best model saved


Epoch 3/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.29it/s]


Epoch 3 | Train Loss: 0.2390 | Valid AUC: 0.7437
  -> New best model saved


Epoch 4/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.83it/s]


Epoch 4 | Train Loss: 0.2330 | Valid AUC: 0.7505
  -> New best model saved


Epoch 5/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.73it/s]


Epoch 5 | Train Loss: 0.2286 | Valid AUC: 0.7569
  -> New best model saved


Epoch 6/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.59it/s]


Epoch 6 | Train Loss: 0.2239 | Valid AUC: 0.7607
  -> New best model saved


Epoch 7/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.23it/s]


Epoch 7 | Train Loss: 0.2205 | Valid AUC: 0.7697
  -> New best model saved


Epoch 8/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.83it/s]


Epoch 8 | Train Loss: 0.2175 | Valid AUC: 0.7701
  -> New best model saved


Epoch 9/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.90it/s]


Epoch 9 | Train Loss: 0.2156 | Valid AUC: 0.7737
  -> New best model saved


Epoch 10/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.15it/s]


Epoch 10 | Train Loss: 0.2113 | Valid AUC: 0.7849
  -> New best model saved


Epoch 11/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.56it/s]


Epoch 11 | Train Loss: 0.2115 | Valid AUC: 0.7835


Epoch 12/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.68it/s]


Epoch 12 | Train Loss: 0.2073 | Valid AUC: 0.7891
  -> New best model saved


Epoch 13/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.25it/s]


Epoch 13 | Train Loss: 0.2062 | Valid AUC: 0.7923
  -> New best model saved


Epoch 14/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.08it/s]


Epoch 14 | Train Loss: 0.2046 | Valid AUC: 0.7914


Epoch 15/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 35.83it/s]


Epoch 15 | Train Loss: 0.2010 | Valid AUC: 0.7929
  -> New best model saved


Epoch 16/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.26it/s]


Epoch 16 | Train Loss: 0.2011 | Valid AUC: 0.8007
  -> New best model saved


Epoch 17/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.24it/s]


Epoch 17 | Train Loss: 0.1983 | Valid AUC: 0.8021
  -> New best model saved


Epoch 18/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.12it/s]


Epoch 18 | Train Loss: 0.1949 | Valid AUC: 0.8062
  -> New best model saved


Epoch 19/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.08it/s]


Epoch 19 | Train Loss: 0.1952 | Valid AUC: 0.8089
  -> New best model saved


Epoch 20/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 35.90it/s]


Epoch 20 | Train Loss: 0.1958 | Valid AUC: 0.8086


Epoch 21/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.12it/s]


Epoch 21 | Train Loss: 0.1933 | Valid AUC: 0.8110
  -> New best model saved


Epoch 22/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.00it/s]


Epoch 22 | Train Loss: 0.1934 | Valid AUC: 0.8114
  -> New best model saved


Epoch 23/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.87it/s]


Epoch 23 | Train Loss: 0.1916 | Valid AUC: 0.8141
  -> New best model saved


Epoch 24/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 36.78it/s]


Epoch 24 | Train Loss: 0.1894 | Valid AUC: 0.8126


Epoch 25/25 [Fine-tune]: 100%|██████████| 196/196 [00:05<00:00, 37.29it/s]


Epoch 25 | Train Loss: 0.1902 | Valid AUC: 0.8128

--- Evaluating best model on tox21 test set ---


[Final Test]: 100%|██████████| 25/25 [00:00<00:00, 26.87it/s]



FINAL TEST METRIC for tox21: ROC-AUC = 0.8186

           GRASP FINAL BENCHMARK RESULTS
| Dataset  | Metric     | Score       |
|----------|------------|-------------|
| BBBP     | ROC-AUC    | 0.9400      |
| Tox21    | ROC-AUC    | 0.8186      |
| ESOL     | RMSE       | 0.8986      |
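
The ROC-AUC values reported for the classification sets are means over tasks, computed after dropping unlabeled (NaN) entries and skipping any task that has only one class left. The sketch below illustrates that averaging in plain Python. It is not the notebook's code: `roc_auc` is a hypothetical pairwise implementation standing in for the `roc_auc_score` call in the fine-tuning function, and the toy labels and scores are made up for the example.

```python
import math

def roc_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def mean_task_auc(y_true, y_score):
    """Average AUC over tasks that still have both classes after NaN labels are dropped."""
    aucs = []
    for task in range(len(y_true[0])):
        pairs = [(t[task], s[task]) for t, s in zip(y_true, y_score) if not math.isnan(t[task])]
        task_labels = [t for t, _ in pairs]
        if len(set(task_labels)) > 1:
            aucs.append(roc_auc(task_labels, [s for _, s in pairs]))
    return sum(aucs) / len(aucs) if aucs else 0.0

nan = float('nan')
toy_labels = [[1.0, 0.0, nan], [0.0, 0.0, 1.0], [1.0, 0.0, nan], [0.0, 0.0, 1.0]]
toy_scores = [[0.9, 0.1, 0.5], [0.2, 0.3, 0.6], [0.4, 0.2, 0.1], [0.6, 0.8, 0.7]]

toy_auc = mean_task_auc(toy_labels, toy_scores)
print(toy_auc)  # 0.75
```

In the toy data only the first task contributes: the second has no positives and the third has no negatives once its NaN rows are removed, so both are skipped rather than counted as zero.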
