# Experiment 116: TRUE Graph Neural Network

**Goal**: Implement a TRUE GNN that operates on molecular graphs with message-passing layers.

**Key Differences from Previous "GNN" Experiments**:
1. Convert SMILES to molecular graphs (atoms as nodes, bonds as edges)
2. Use GCNConv message-passing layers from PyTorch Geometric
3. Use atom features (element type, degree, formal charge, aromaticity, hydrogen count)
4. Apply global pooling to get molecule-level embeddings

**Hypothesis**: A true GNN may capture structural patterns that generalize to unseen solvents, potentially changing the CV-LB relationship.

**CRITICAL**: The model class `TrueGNNModel` will be used in BOTH CV computation AND submission cells.

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import tqdm
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: one seed shared by NumPy and PyTorch, deterministic cuDNN kernels.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_default_dtype(torch.float32)  # Use float32 for GNN

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# PyTorch Geometric imports
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool
from torch_geometric.data import Data, Batch
print('PyTorch Geometric imported successfully')

PyTorch Geometric imported successfully


In [3]:
# RDKit imports for molecular graph construction
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
print('RDKit imported successfully')

RDKit imported successfully


In [4]:
# Data loading functions
# Directory that holds the yield CSVs and the solvent lookup tables.
DATA_PATH = '/home/data'

# Input columns: the two numeric process conditions, extended with the
# solvent identifier column(s) of the single-solvent and mixture datasets.
INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
INPUT_LABELS_SINGLE_SOLVENT = [*INPUT_LABELS_NUMERIC, "SOLVENT NAME"]
INPUT_LABELS_FULL_SOLVENT = [
    *INPUT_LABELS_NUMERIC,
    "SOLVENT A NAME",
    "SOLVENT B NAME",
    "SolventB%",
]

def load_data(name="full"):
    """Read one of the two yield tables and split it into inputs and targets.

    Args:
        name: "full" selects the solvent-mixture table; any other value
            selects the single-solvent table.

    Returns:
        (X, Y): X holds the input columns for the chosen dataset, Y holds the
        "Product 2", "Product 3" and "SM" columns.
    """
    if name == "full":
        csv_name = 'catechol_full_data_yields.csv'
        input_columns = INPUT_LABELS_FULL_SOLVENT
    else:
        csv_name = 'catechol_single_solvent_yields.csv'
        input_columns = INPUT_LABELS_SINGLE_SOLVENT

    frame = pd.read_csv(f'{DATA_PATH}/{csv_name}')
    return frame[input_columns], frame[["Product 2", "Product 3", "SM"]]

def generate_leave_one_out_splits(X, Y):
    """Yield leave-one-solvent-out folds in sorted solvent order.

    Each item is ((X_train, Y_train), (X_test, Y_test)) where the test part
    contains exactly the rows whose "SOLVENT NAME" equals the held-out solvent.
    """
    solvent_column = X["SOLVENT NAME"]
    for held_out in sorted(solvent_column.unique()):
        is_test = solvent_column == held_out
        is_train = ~is_test
        yield (X[is_train], Y[is_train]), (X[is_test], Y[is_test])

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield folds that each hold out one (solvent A, solvent B) pair.

    Pairs are visited in order of first appearance in X. Each item is
    ((X_train, Y_train), (X_test, Y_test)); the test part contains every row
    whose "SOLVENT A NAME" and "SOLVENT B NAME" both match the held-out pair.
    """
    col_a = X["SOLVENT A NAME"]
    col_b = X["SOLVENT B NAME"]
    unique_pairs = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()

    for name_a, name_b in zip(unique_pairs["SOLVENT A NAME"], unique_pairs["SOLVENT B NAME"]):
        is_test = (col_a == name_a) & (col_b == name_b)
        is_train = ~is_test
        yield (X[is_train], Y[is_train]), (X[is_test], Y[is_test])

print('Data loading functions defined')

Data loading functions defined


In [5]:
# Solvent name -> SMILES string table (first CSV column becomes the index)
SMILES_DF = pd.read_csv('/'.join((DATA_PATH, 'smiles_lookup.csv')), index_col=0)
print('SMILES lookup: {}'.format(SMILES_DF.shape))
print(SMILES_DF.head())

SMILES lookup: (26, 1)
                                           solvent smiles
SOLVENT NAME                                             
Cyclohexane                                      C1CCCCC1
Ethyl Acetate                                   O=C(OCC)C
Acetic Acid                                       CC(=O)O
2-Methyltetrahydrofuran [2-MeTHF]              O1C(C)CCC1
1,1,1,3,3,3-Hexafluoropropan-2-ol  C(C(F)(F)F)(C(F)(F)F)O


In [6]:
# Molecular graph construction from SMILES
# Atomic number -> slot in the 10-wide atom-type one-hot.
# Slots: C, N, O, S, F, Cl, Br, I, P; slot 9 collects every other element.
_ATOM_TYPE_SLOTS = {6: 0, 7: 1, 8: 2, 16: 3, 9: 4, 17: 5, 35: 6, 53: 7, 15: 8}
_OTHER_ATOM_SLOT = 9

# Width of one node-feature row: atom type (10) + degree (6) + charge (5)
# + aromatic flag (1) + scaled hydrogen count (1).
_NODE_FEATURE_DIM = 10 + 6 + 5 + 1 + 1  # = 23


def _one_hot(index, size):
    """Return a list of `size` zeros with a single 1 at position `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec


def smiles_to_graph(smiles):
    """Convert SMILES to PyTorch Geometric Data object.

    Node features (per atom, 23 values in this order):
    - Atom type one-hot, 10 slots: C, N, O, S, F, Cl, Br, I, P, other
    - Degree one-hot, 6 slots (0-5, larger degrees are clipped to 5)
    - Formal charge one-hot, 5 slots (-2 to +2, values outside are clipped)
    - Is aromatic (0/1)
    - Total number of hydrogens, clipped to 4 and scaled to [0, 1]

    Edges:
    - Every bond is stored twice (i -> j and j -> i) in `edge_index`.
      No edge features (e.g. bond type) are produced.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        Data with `x` of shape (num_atoms, 23) and `edge_index` of shape
        (2, 2 * num_bonds). If the SMILES cannot be parsed, or parses to a
        molecule without atoms, a dummy graph with one all-zero node and no
        edges is returned instead.
    """
    mol = Chem.MolFromSmiles(smiles)
    # A molecule with no atoms would give `x` the shape (0,) rather than
    # (n, 23), so it is treated like an unparsable SMILES.
    if mol is None or mol.GetNumAtoms() == 0:
        return Data(
            x=torch.zeros((1, _NODE_FEATURE_DIM), dtype=torch.float32),
            edge_index=torch.zeros((2, 0), dtype=torch.long)
        )

    # Atom features
    atom_features = []
    for atom in mol.GetAtoms():
        atom_type = _one_hot(_ATOM_TYPE_SLOTS.get(atom.GetAtomicNum(), _OTHER_ATOM_SLOT), 10)
        degree_onehot = _one_hot(min(atom.GetDegree(), 5), 6)
        # Shift charge -2..+2 onto indices 0..4, clipping anything outside.
        charge_onehot = _one_hot(min(max(atom.GetFormalCharge() + 2, 0), 4), 5)
        is_aromatic = [1 if atom.GetIsAromatic() else 0]
        num_h = min(atom.GetTotalNumHs(), 4)

        atom_features.append(
            atom_type + degree_onehot + charge_onehot + is_aromatic + [num_h / 4.0]
        )

    x = torch.tensor(atom_features, dtype=torch.float32)

    # Edge index (bonds); both directions are added to model an undirected graph
    directed_pairs = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        directed_pairs.extend(([i, j], [j, i]))

    if directed_pairs:
        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
    else:
        # Molecules without bonds (single heavy atom) get an empty edge list.
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

# Test with a simple molecule
test_graph = smiles_to_graph('CCO')  # Ethanol
assert test_graph.x.shape == (3, _NODE_FEATURE_DIM), "unexpected node feature shape"
assert test_graph.num_edges == 4, "ethanol has 2 bonds -> 4 directed edges"
print(f'Test graph: {test_graph.num_nodes} nodes, {test_graph.num_edges} edges')
print(f'Node features shape: {test_graph.x.shape}')

Test graph: 3 nodes, 4 edges
Node features shape: torch.Size([3, 23])


In [7]:
# Pre-compute molecular graphs for all solvents
def _graph_for_solvent(solvent):
    """Build the graph for one lookup row; non-string SMILES fall back to 'C'."""
    entry = SMILES_DF.loc[solvent, 'solvent smiles']
    return smiles_to_graph(entry if isinstance(entry, str) else 'C')  # 'C' = dummy


# One graph per solvent in the SMILES lookup, keyed by solvent name.
SOLVENT_GRAPHS = {solvent: _graph_for_solvent(solvent) for solvent in SMILES_DF.index}

print(f'Pre-computed graphs for {len(SOLVENT_GRAPHS)} solvents')

# Show some statistics (first five solvents only)
for solvent, g in list(SOLVENT_GRAPHS.items())[:5]:
    print(f'  {solvent}: {g.num_nodes} atoms, {g.num_edges} bonds')

Pre-computed graphs for 26 solvents
  Cyclohexane: 6 atoms, 12 bonds
  Ethyl Acetate: 6 atoms, 10 bonds
  Acetic Acid: 4 atoms, 6 bonds
  2-Methyltetrahydrofuran [2-MeTHF]: 6 atoms, 12 bonds
  1,1,1,3,3,3-Hexafluoropropan-2-ol: 10 atoms, 18 bonds


In [8]:
# Spange solvent descriptors, indexed by solvent name; fed to the model's process-condition encoder
SPANGE_DF = pd.read_csv('/'.join((DATA_PATH, 'spange_descriptors_lookup.csv')), index_col=0)
print('Spange descriptors: {}'.format(SPANGE_DF.shape))

Spange descriptors: (26, 13)


In [9]:
# TRUE GNN Model with message-passing layers
class TrueGNNModel(nn.Module):
    """True Graph Neural Network for solvent yield prediction.

    Architecture:
    1. Linear embedding of the per-atom features, then GCNConv message-passing
       layers (each followed by BatchNorm and ReLU) on the molecular graph
    2. Global mean pooling to get one embedding per molecule
    3. A small MLP encodes the process conditions (5 kinetic features plus
       the Spange descriptors of the solvent / solvent mixture)
    4. Molecule embedding(s), encoded process conditions and, for mixtures,
       the "SolventB%" value are concatenated and fed to a sigmoid MLP head
       that outputs the 3 targets

    This is the SAME class used in both CV and submission cells.
    """

    NODE_FEATURE_DIM = 23      # width of the node features built by smiles_to_graph
    NUM_KINETIC_FEATURES = 5   # number of columns returned by get_kinetic_features
    NUM_TARGETS = 3            # "Product 2", "Product 3", "SM"

    def __init__(self, data='single', hidden_dim=64, num_gnn_layers=3):
        """Build the network.

        Args:
            data: 'full' for solvent mixtures (two graphs + mixture ratio);
                any other value selects the single-solvent layout.
            hidden_dim: width of the node embeddings / molecule embedding.
            num_gnn_layers: number of GCNConv + BatchNorm layers.
        """
        super().__init__()
        self.data_type = data
        self.hidden_dim = hidden_dim

        # Lookup tables (not training data): solvent name -> Spange descriptors / graph
        self.spange_df = SPANGE_DF
        self.solvent_graphs = SOLVENT_GRAPHS

        # GNN layers
        self.node_embed = nn.Linear(self.NODE_FEATURE_DIM, hidden_dim)
        self.convs = nn.ModuleList([
            GCNConv(hidden_dim, hidden_dim) for _ in range(num_gnn_layers)
        ])
        self.conv_bns = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim) for _ in range(num_gnn_layers)
        ])

        # Process condition encoder: kinetic features + one value per Spange column
        # (5 + 13 = 18 for the lookup table loaded in this notebook)
        self.process_dim = self.NUM_KINETIC_FEATURES + self.spange_df.shape[1]
        self.process_encoder = nn.Sequential(
            nn.Linear(self.process_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )

        # Output head
        # For mixtures: 2 * hidden_dim (two solvents) + 32 (process) + 1 (mixture ratio)
        # For single: hidden_dim + 32 (process)
        if data == 'full':
            head_input_dim = 2 * hidden_dim + 32 + 1
        else:
            head_input_dim = hidden_dim + 32

        self.head = nn.Sequential(
            nn.Linear(head_input_dim, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, self.NUM_TARGETS),
            nn.Sigmoid()
        )

    def encode_molecule(self, graph):
        """Encode a (batched) molecular graph to one embedding per molecule.

        Args:
            graph: object exposing `x`, `edge_index` and optionally `batch`.

        Returns:
            Tensor of shape (num_graphs, hidden_dim); a graph without a
            `batch` vector is treated as a single molecule.
        """
        x = self.node_embed(graph.x)

        for conv, bn in zip(self.convs, self.conv_bns):
            x = conv(x, graph.edge_index)
            x = bn(x)
            x = F.relu(x)

        # Global mean pooling
        if hasattr(graph, 'batch') and graph.batch is not None:
            x = global_mean_pool(x, graph.batch)
        else:
            x = x.mean(dim=0, keepdim=True)

        return x

    def forward(self, process, solvent_batches, ratio=None):
        """Predict the targets for one mini-batch.

        Args:
            process: (batch, process_dim) tensor of process features.
            solvent_batches: sequence of batched graphs, one per solvent slot
                (one entry for single-solvent data, two for mixtures). They
                are encoded in the given order.
            ratio: optional (batch, 1) tensor with the mixture ratio; it is
                appended as the last input column of the head when given.

        Returns:
            (batch, 3) tensor of sigmoid outputs.
        """
        embeddings = [self.encode_molecule(graphs) for graphs in solvent_batches]
        parts = embeddings + [self.process_encoder(process)]
        if ratio is not None:
            parts.append(ratio)
        return self.head(torch.cat(parts, dim=1))

    def get_kinetic_features(self, X):
        """Extract kinetic features from input data.

        Returns a float32 array with 5 columns: the two raw numeric inputs
        (residence time, temperature), 1000 / (temperature + 273.15),
        log(residence time + 1e-6) and the product of the last two.
        The +273.15 shift assumes the temperature column is in degrees Celsius.
        """
        X_vals = X[INPUT_LABELS_NUMERIC].values.astype(np.float32)
        temp_c = X_vals[:, 1:2]
        time_m = X_vals[:, 0:1]
        temp_k = temp_c + 273.15
        inv_temp = 1000.0 / temp_k
        log_time = np.log(time_m + 1e-6)  # epsilon guards against log(0)
        interaction = inv_temp * log_time
        return np.hstack([X_vals, inv_temp, log_time, interaction])

    def _prepare_inputs(self, X):
        """Turn an input frame into arrays shared by training and prediction.

        Returns:
            process_feats: (n, process_dim) float32 array (kinetic + Spange).
            solvent_names: tuple with one name array per solvent slot.
            mixture_ratios: (n, 1) float32 array for 'full' data, else None.
        """
        kinetic_feats = self.get_kinetic_features(X)

        if self.data_type == 'full':
            solvent_names = (X["SOLVENT A NAME"].values, X["SOLVENT B NAME"].values)
            mixture_ratios = X["SolventB%"].values.astype(np.float32).reshape(-1, 1)
            spange_a = self.spange_df.loc[solvent_names[0]].values.astype(np.float32)
            spange_b = self.spange_df.loc[solvent_names[1]].values.astype(np.float32)
            # Weighted average of Spange features
            spange_feats = spange_a * (1 - mixture_ratios) + spange_b * mixture_ratios
        else:
            solvent_names = (X["SOLVENT NAME"].values,)
            mixture_ratios = None
            spange_feats = self.spange_df.loc[solvent_names[0]].values.astype(np.float32)

        process_feats = np.hstack([kinetic_feats, spange_feats])
        return process_feats, solvent_names, mixture_ratios

    def _forward_rows(self, process_feats, solvent_names, mixture_ratios, rows):
        """Run the network on the selected rows of the prepared arrays.

        `rows` is anything usable as a NumPy index (index array or slice).
        """
        process = torch.tensor(process_feats[rows], dtype=torch.float32).to(device)

        # One batched graph per solvent slot; graphs are cloned so the shared
        # lookup entries are never modified.
        solvent_batches = [
            Batch.from_data_list(
                [self.solvent_graphs[name].clone() for name in names[rows]]
            ).to(device)
            for names in solvent_names
        ]

        ratio = None
        if mixture_ratios is not None:
            ratio = torch.tensor(mixture_ratios[rows], dtype=torch.float32).to(device)

        return self(process, solvent_batches, ratio)

    def train_model(self, X_train, y_train, epochs=200, batch_size=32, lr=1e-3):
        """Train the GNN model.

        Args:
            X_train: input frame with the columns of the chosen dataset.
            y_train: target frame with 3 columns.
            epochs: number of passes over the training data.
            batch_size: upper bound on the mini-batch size (at least 2 is used).
            lr: initial Adam learning rate.

        Raises:
            ValueError: if fewer than 2 training rows are given; the
                BatchNorm layers cannot be trained on a single sample.
        """
        n_samples = len(X_train)
        if n_samples < 2:
            raise ValueError(
                f"train_model needs at least 2 samples for BatchNorm, got {n_samples}"
            )

        self.to(device)
        self.train()

        # Prepare training data
        process_feats, solvent_names, mixture_ratios = self._prepare_inputs(X_train)
        y_vals = y_train.values.astype(np.float32)

        # Training loop
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=1e-5)
        criterion = nn.HuberLoss()
        # The deprecated `verbose` argument is not passed (False was the default).
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=20
        )

        batch_size = max(2, min(batch_size, n_samples))

        for _ in range(epochs):
            epoch_loss = 0.0
            indices = np.random.permutation(n_samples)
            batches = [
                indices[start:start + batch_size]
                for start in range(0, n_samples, batch_size)
            ]
            # BatchNorm1d raises in training mode on a batch of one sample,
            # so a trailing singleton is merged into the previous batch.
            if len(batches) > 1 and len(batches[-1]) == 1:
                batches[-2:] = [np.concatenate(batches[-2:])]

            for batch_idx in batches:
                batch_y = torch.tensor(y_vals[batch_idx], dtype=torch.float32).to(device)

                optimizer.zero_grad()
                pred = self._forward_rows(process_feats, solvent_names, mixture_ratios, batch_idx)
                loss = criterion(pred, batch_y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                optimizer.step()

                # Weight by batch length so the epoch loss is a per-sample mean
                epoch_loss += loss.item() * len(batch_idx)

            scheduler.step(epoch_loss / n_samples)

    def predict(self, X):
        """Make predictions.

        Args:
            X: input frame with the columns of the chosen dataset.

        Returns:
            CPU tensor of shape (len(X), 3); all rows are processed in a
            single batch with the model in eval mode.
        """
        self.to(device)
        self.eval()

        if len(X) == 0:
            return torch.empty((0, self.NUM_TARGETS), dtype=torch.float32)

        process_feats, solvent_names, mixture_ratios = self._prepare_inputs(X)

        with torch.no_grad():
            pred = self._forward_rows(process_feats, solvent_names, mixture_ratios, slice(None))

        return pred.cpu()

print('TrueGNNModel defined - will be used in both CV and submission cells')

TrueGNNModel defined - will be used in both CV and submission cells


In [10]:
# Leave-one-solvent-out cross-validation on the single-solvent data
print("Computing CV score...")

X_single, Y_single = load_data("single_solvent")
single_mses = []

single_folds = generate_leave_one_out_splits(X_single, Y_single)
for fold_idx, ((train_X, train_Y), (test_X, test_Y)) in enumerate(single_folds):
    # A fresh model is trained for every fold
    model = TrueGNNModel(data='single')  # SAME CLASS AS SUBMISSION
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X).numpy()
    targets = test_Y.values

    squared_error = (predictions - targets) ** 2
    mse = np.mean(squared_error)
    single_mses.append(mse)

    # Only every 6th fold is logged to keep the output short
    if not fold_idx % 6:
        print(f"  Fold {fold_idx}: MSE = {mse:.6f}")

# Unweighted mean over folds
single_mse = np.mean(single_mses)
print(f"\nSingle solvent MSE: {single_mse:.6f}")

Computing CV score...


  Fold 0: MSE = 0.045959


  Fold 6: MSE = 0.006228


  Fold 12: MSE = 0.003432


  Fold 18: MSE = 0.006570



Single solvent MSE: 0.012039


In [11]:
# Leave-one-solvent-pair-out cross-validation on the mixture data
X_full, Y_full = load_data("full")
full_mses = []

full_folds = generate_leave_one_ramp_out_splits(X_full, Y_full)
for fold_idx, ((train_X, train_Y), (test_X, test_Y)) in enumerate(full_folds):
    # A fresh model is trained for every fold
    model = TrueGNNModel(data='full')  # SAME CLASS AS SUBMISSION
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X).numpy()
    targets = test_Y.values

    squared_error = (predictions - targets) ** 2
    mse = np.mean(squared_error)
    full_mses.append(mse)

    # Only every 3rd fold is logged to keep the output short
    if not fold_idx % 3:
        print(f"  Fold {fold_idx}: MSE = {mse:.6f}")

# Unweighted mean over folds
full_mse = np.mean(full_mses)
print(f"\nFull data MSE: {full_mse:.6f}")

  Fold 0: MSE = 0.008819


  Fold 3: MSE = 0.010706


  Fold 6: MSE = 0.015453


  Fold 9: MSE = 0.004356


  Fold 12: MSE = 0.005905

Full data MSE: 0.010542


In [None]:
# Combined CV score
import json
import os

# Best combined CV score of the earlier experiments, used as the comparison baseline.
BEST_CV_SCORE = 0.0081
METRICS_PATH = '/home/code/experiments/116_true_gnn/metrics.json'

# Unweighted mean of the two task-level MSEs
cv_score = (single_mse + full_mse) / 2
print("\n=== CV Results ===")
print(f"Single solvent MSE: {single_mse:.6f}")
print(f"Full data MSE: {full_mse:.6f}")
print(f"Combined CV score: {cv_score:.6f}")

# Save metrics. Values are cast to plain floats so json.dump never has to
# deal with NumPy scalar types.
metrics = {
    'cv_score': float(cv_score),
    'single_mse': float(single_mse),
    'full_mse': float(full_mse)
}
# Create the experiment directory if needed so the write cannot fail after
# the long CV run.
os.makedirs(os.path.dirname(METRICS_PATH), exist_ok=True)
with open(METRICS_PATH, 'w') as f:
    json.dump(metrics, f)

print(f"\nComparison with best CV: {BEST_CV_SCORE}")
print(f"This experiment: {cv_score:.6f}")
if cv_score < BEST_CV_SCORE:
    print("IMPROVEMENT! This is better than best CV.")
else:
    print(f"No improvement. Difference: {cv_score - BEST_CV_SCORE:.6f}")

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = TrueGNNModel(data='single')  # only editable line of this template; SAME CLASS AS CV so CV and submission scores stay comparable
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = TrueGNNModel(data='full')  # only editable line of this template; SAME CLASS AS CV so CV and submission scores stay comparable
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

print(f"Submission saved with {len(submission)} rows")
print(submission.head())

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################