# Experiment 020: Graph Neural Network (GNN) for Molecular Property Prediction

**Key insight from research:**
- Paper arxiv:2512.19530 achieved MSE 0.0039 using GNN (25x better than tabular ensembles)
- GNN can learn molecular structure patterns that generalize to unseen solvents
- This is the only approach with demonstrated target-level performance

**Architecture:**
- Use RDKit to convert solvent SMILES to molecular graphs
- Graph Convolutional Network (GCN) for molecular encoding
- Combine molecular embeddings with process conditions (Residence Time, Temperature, and SolventB% for mixed solvents)
- Single shared MLP prediction head with one sigmoid output per target (Product 2, Product 3, SM)

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from abc import ABC
import tqdm
import warnings
warnings.filterwarnings('ignore')

# GNN imports
from rdkit import Chem
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool

DATA_PATH = '/home/data'
SEED = 42  # fixed seed so that Restart & Run All repeats the same run

# Seed every RNG this notebook draws from: NumPy shuffles the mini-batches,
# PyTorch initialises the network weights and samples the dropout masks.
# (Some GPU kernels are still non-deterministic, so CUDA runs may differ slightly.)
np.random.seed(SEED)
torch.manual_seed(SEED)

torch.set_default_dtype(torch.float32)  # PyG works better with float32
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')

CUDA available: True
GPU: NVIDIA H100 80GB HBM3


In [2]:
# --- UTILITY FUNCTIONS ---
TARGET_LABELS = ["Product 2", "Product 3", "SM"]

def load_data(name="full"):
    """Read one of the two yield tables and split it into inputs and targets.

    Args:
        name: "full" for the solvent-mixture table or "single_solvent" for the
            single-solvent table. Any other value fails the assertion.

    Returns:
        (X, Y): X holds the process conditions and solvent name column(s) for
        the chosen table, Y holds the TARGET_LABELS columns.
    """
    assert name in ["full", "single_solvent"]

    # One entry per dataset: (CSV location, input columns to keep).
    sources = {
        "full": (
            f'{DATA_PATH}/catechol_full_data_yields.csv',
            ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"],
        ),
        "single_solvent": (
            f'{DATA_PATH}/catechol_single_solvent_yields.csv',
            ["Residence Time", "Temperature", "SOLVENT NAME"],
        ),
    }
    csv_path, input_columns = sources[name]

    table = pd.read_csv(csv_path)
    inputs = table[input_columns]
    targets = table[TARGET_LABELS]
    return inputs, targets

def generate_leave_one_out_splits(X, Y):
    """Yield leave-one-solvent-out splits, one per distinct "SOLVENT NAME".

    Solvents are visited in sorted order. Each item is
    ``((train_X, train_Y), (test_X, test_Y))`` where the test part contains
    exactly the rows of the held-out solvent.
    """
    solvent_names = X["SOLVENT NAME"]
    for held_out in sorted(solvent_names.unique()):
        is_test = solvent_names == held_out
        is_train = ~is_test
        yield (X[is_train], Y[is_train]), (X[is_test], Y[is_test])

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield leave-one-ramp-out splits for the mixed-solvent data.

    A "ramp" is one distinct (SOLVENT A NAME, SOLVENT B NAME) pair. Pairs are
    visited in order of first appearance. Each item is
    ``((train_X, train_Y), (test_X, test_Y))`` where the test part contains
    exactly the rows of the held-out pair.
    """
    solvent_a = X["SOLVENT A NAME"]
    solvent_b = X["SOLVENT B NAME"]
    unique_pairs = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()

    for name_a, name_b in zip(unique_pairs["SOLVENT A NAME"], unique_pairs["SOLVENT B NAME"]):
        in_ramp = (solvent_a == name_a) & (solvent_b == name_b)
        outside_ramp = ~in_ramp
        yield (X[outside_ramp], Y[outside_ramp]), (X[in_ramp], Y[in_ramp])

# Solvent name -> SMILES string, read from the lookup table whose first
# column is used as the index.
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
SMILES_DICT = dict(SMILES_DF['solvent smiles'].items())
print(f"SMILES lookup: {len(SMILES_DICT)} solvents")

SMILES lookup: 26 solvents


In [3]:
# --- BASE CLASSES ---
class SmilesFeaturizer(ABC):
    """Placeholder interface for featurizers.

    Both methods raise ``NotImplementedError``; a concrete featurizer is
    expected to override them. No subclass is defined in this notebook.
    """

    def __init__(self):
        raise NotImplementedError

    def featurize(self, X):
        """Turn the inputs ``X`` into features (to be provided by subclasses)."""
        raise NotImplementedError

class BaseModel(ABC):
    """Minimal interface shared by the models in this notebook.

    Subclasses override ``train_model`` and ``predict``; the base versions only
    raise ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def train_model(self, X_train, y_train):
        """Fit the model on training inputs ``X_train`` and targets ``y_train``."""
        raise NotImplementedError

    def predict(self, X=None):
        """Return predictions for the inputs ``X``.

        ``X`` is accepted here so the base signature matches the way subclasses
        are called (``model.predict(test_X)``); it defaults to ``None`` so that
        an existing argument-less call still reaches ``NotImplementedError``.
        """
        raise NotImplementedError

In [4]:
# --- MOLECULAR GRAPH UTILITIES ---
def smiles_to_graph(smiles):
    """Convert a SMILES string to a PyTorch Geometric ``Data`` object.

    Nodes are the atoms RDKit returns for the parsed molecule; each carries five
    float features: atomic number, degree, formal charge, hybridization (enum
    cast to int) and an aromaticity flag. Every bond is stored as two directed
    edges so the graph is effectively undirected.

    Parsing fallbacks, tried in order when RDKit returns ``None``:
      1. the text before the first '.' in ``smiles``;
      2. water ('O').
    Note that a dotted SMILES that RDKit can parse is NOT reduced to its first
    component: 'O.CC#N' yields a single 4-atom graph containing both fragments.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable as given: retry with the first dot-separated component.
        mol = Chem.MolFromSmiles(smiles.split('.')[0])
    if mol is None:
        # Still unparseable: silently substitute a water molecule.
        mol = Chem.MolFromSmiles('O')

    # Node feature matrix: [atomic_num, degree, formal_charge, hybridization, aromatic]
    node_rows = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
        ]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(node_rows, dtype=torch.float32)

    # Each bond contributes both directions (begin -> end, end -> begin).
    directed_pairs = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        directed_pairs.extend(([begin, end], [end, begin]))

    if directed_pairs:
        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
    else:
        # No bonds at all (e.g. a single-atom molecule): empty [2, 0] index.
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

# Test: print AND assert, so a regression in graph construction stops the run
# instead of only changing the printed numbers.
test_graph = smiles_to_graph('CCO')  # Ethanol
print(f"Ethanol graph: {test_graph.num_nodes} atoms, {test_graph.num_edges} edges")
# 3 heavy atoms; 2 bonds, each stored in both directions.
assert (test_graph.num_nodes, test_graph.num_edges) == (3, 4)
test_graph = smiles_to_graph('O.CC#N')  # Water.Acetonitrile mixture
print(f"Water.Acetonitrile graph: {test_graph.num_nodes} atoms, {test_graph.num_edges} edges")
# Both fragments are kept: 1 + 3 atoms; only acetonitrile has bonds (2, both directions).
assert (test_graph.num_nodes, test_graph.num_edges) == (4, 4)

Ethanol graph: 3 atoms, 4 edges
Water.Acetonitrile graph: 4 atoms, 4 edges


In [5]:
# --- GNN MODEL ---
class MolecularGNN(nn.Module):
    """Graph Neural Network for molecular property prediction.

    Architecture:
    - 3 GCN layers for molecular encoding
    - Global mean pooling to get molecule-level embedding
    - Small MLP encoding the process conditions
    - Single shared MLP head with a sigmoid on every output

    Args:
        atom_features: number of input features per atom (node).
        hidden_dim: width of the GCN layers and of the pooled molecule embedding.
        output_dim: number of predicted targets.
        condition_dim: number of process-condition inputs
            (default 3: [RT, Temp, SolventB%]).
        condition_hidden: width of the condition embedding.
        head_hidden: width of the hidden layer in the prediction head.
        dropout: dropout probability inside the prediction head.

    The defaults reproduce the previously hard-coded architecture.
    """
    def __init__(self, atom_features=5, hidden_dim=64, output_dim=3,
                 condition_dim=3, condition_hidden=32, head_hidden=64, dropout=0.3):
        super().__init__()

        # GCN layers (attribute names and creation order are kept stable so
        # state_dict keys and seeded initialisation do not change)
        self.conv1 = GCNConv(atom_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Process condition encoder
        self.condition_encoder = nn.Sequential(
            nn.Linear(condition_dim, condition_hidden),  # [RT, Temp, SolventB%]
            nn.ReLU(),
            nn.Linear(condition_hidden, condition_hidden)
        )

        # Prediction head; the final Sigmoid bounds every output to (0, 1)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim + condition_hidden, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, output_dim),
            nn.Sigmoid()
        )

    def forward(self, graph_batch, conditions):
        """Predict targets for a batch of molecules and their conditions.

        Args:
            graph_batch: batched graphs exposing ``x``, ``edge_index`` and
                ``batch`` (as produced by ``Batch.from_data_list``).
            conditions: tensor of shape [batch_size, condition_dim].

        Returns:
            Tensor of shape [batch_size, output_dim] with values in (0, 1).
        """
        # Graph encoding
        x, edge_index, batch = graph_batch.x, graph_batch.edge_index, graph_batch.batch

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        # Global pooling
        x = global_mean_pool(x, batch)  # [batch_size, hidden_dim]

        # Condition encoding
        cond = self.condition_encoder(conditions)  # [batch_size, condition_hidden]

        # Combine and predict
        combined = torch.cat([x, cond], dim=1)
        out = self.head(combined)

        return out

print("MolecularGNN defined")

MolecularGNN defined


In [6]:
# --- GNN MODEL WRAPPER ---
class GNNModel(BaseModel):
    """GNN-based model for solvent yield prediction.

    Single solvents (``data='single'``): each sample is represented by the
    molecular graph of its ``SOLVENT NAME``.

    Mixed solvents (``data='full'``): each sample is represented by the graph
    of ``SOLVENT A NAME`` only. Solvent B reaches the network solely through
    the ``SolventB%`` fraction in the condition vector; its molecular graph is
    not used.
    TODO: blend the embeddings of solvents A and B weighted by SolventB%.

    Solvent names missing from ``SMILES_DICT`` fall back to a water graph.
    """
    def __init__(self, data='single'):
        self.data_type = data
        self.mixed = (data == 'full')
        self.model = None  # created by train_model()
        self.scaler = StandardScaler()

        # Pre-compute graphs for all solvents
        self.solvent_graphs = {
            name: smiles_to_graph(smiles) for name, smiles in SMILES_DICT.items()
        }
        # Built once and shared by every row whose solvent is not in the lookup
        self._fallback_graph = smiles_to_graph('O')

    def _get_conditions(self, X):
        """Return an [n_samples, 3] array: residence time, temperature, solvent-B fraction.

        The third column is ``SolventB% / 100`` for mixed data and all zeros
        for single-solvent data, so both modes feed the same network input size.
        """
        rt = X['Residence Time'].values.reshape(-1, 1)
        temp = X['Temperature'].values.reshape(-1, 1)

        if self.mixed:
            pct = X['SolventB%'].values.reshape(-1, 1) / 100.0
        else:
            pct = np.zeros((len(X), 1))
        return np.hstack([rt, temp, pct])

    def _get_graphs(self, X):
        """Return one molecular graph per row of ``X`` (see class docstring)."""
        # For mixed solvents, use the primary solvent (A)
        column = 'SOLVENT A NAME' if self.mixed else 'SOLVENT NAME'
        return [self.solvent_graphs.get(name, self._fallback_graph) for name in X[column]]

    def train_model(self, X_train, y_train, epochs=100, batch_size=32,
                    lr=1e-3, weight_decay=1e-4):
        """Fit a fresh MolecularGNN with Adam on an MSE loss.

        Args:
            X_train: DataFrame with the condition and solvent-name columns.
            y_train: DataFrame with the three target columns.
            epochs, batch_size, lr, weight_decay: training hyper-parameters;
                the defaults are the values that used to be hard-coded.
        """
        # Get conditions and graphs; the scaler is fitted on the training fold only
        conditions_scaled = self.scaler.fit_transform(self._get_conditions(X_train))
        graphs = self._get_graphs(X_train)

        # Move conditions and targets to the device once instead of per batch
        cond_all = torch.tensor(conditions_scaled, dtype=torch.float32, device=device)
        y_all = torch.tensor(y_train.values, dtype=torch.float32, device=device)

        # Initialize model
        self.model = MolecularGNN(atom_features=5, hidden_dim=64, output_dim=3).to(device)

        # Training
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.MSELoss()

        n_samples = len(graphs)

        self.model.train()
        for _ in range(epochs):
            # Shuffle
            order = np.random.permutation(n_samples)

            for start in range(0, n_samples, batch_size):
                batch_idx = order[start:start + batch_size]  # slice clips at n_samples

                # Graphs must be re-collated per batch because the order changes every epoch
                graph_batch = Batch.from_data_list([graphs[i] for i in batch_idx]).to(device)
                idx = torch.as_tensor(batch_idx, dtype=torch.long, device=device)

                # Forward / backward
                optimizer.zero_grad()
                pred = self.model(graph_batch, cond_all[idx])
                loss = criterion(pred, y_all[idx])
                loss.backward()
                optimizer.step()

    def predict(self, X):
        """Predict the targets for every row of ``X`` in a single batch.

        Returns:
            float64 CPU tensor of shape [len(X), 3].

        Raises:
            RuntimeError: if called before ``train_model``.
        """
        if self.model is None:
            raise RuntimeError("GNNModel.predict() called before train_model()")

        conditions_scaled = self.scaler.transform(self._get_conditions(X))
        graphs = self._get_graphs(X)

        self.model.eval()  # disables dropout
        with torch.no_grad():
            graph_batch = Batch.from_data_list(graphs).to(device)
            cond_tensor = torch.tensor(conditions_scaled, dtype=torch.float32).to(device)
            pred = self.model(graph_batch, cond_tensor)

        # Convert to double for template compatibility
        return pred.cpu().double()

print("GNNModel defined")

GNNModel defined


In [7]:
# --- QUICK VALIDATION TEST ---
N_QUICK_FOLDS = 3  # folds evaluated per dataset in this smoke test


def quick_fold_maes(X, Y, split_fn, data, fold_label, max_folds=N_QUICK_FOLDS):
    """Train and score a GNNModel on the first ``max_folds`` splits.

    Args:
        X, Y: inputs and targets as returned by ``load_data``.
        split_fn: generator function yielding
            ``((train_X, train_Y), (test_X, test_Y))`` tuples.
        data: value passed to ``GNNModel(data=...)`` ('single' or 'full').
        fold_label: callable ``(fold_index, test_X) -> str`` giving the text
            printed in front of ": MAE = ...".
        max_folds: number of folds to evaluate before stopping.

    Returns:
        List with one MAE per evaluated fold, averaged over all rows and all
        three targets of that fold.
    """
    maes = []
    for i, ((train_X, train_Y), (test_X, test_Y)) in enumerate(split_fn(X, Y)):
        if i >= max_folds:
            break
        model = GNNModel(data=data)
        model.train_model(train_X, train_Y)
        preds = model.predict(test_X).numpy()
        mae = np.mean(np.abs(preds - test_Y.values))
        maes.append(mae)
        print(f"{fold_label(i, test_X)}: MAE = {mae:.4f}")
    return maes


print("Testing GNNModel...")
X_test, Y_test = load_data("single_solvent")
errors = quick_fold_maes(
    X_test, Y_test, generate_leave_one_out_splits, 'single',
    lambda i, test_X: f"Single Fold {i} ({test_X['SOLVENT NAME'].iloc[0]})",
)

print(f"\nSingle solvent quick test MAE: {np.mean(errors):.4f}")

# Test full data
print("\nTesting on full data...")
X_full, Y_full = load_data("full")
errors_full = quick_fold_maes(
    X_full, Y_full, generate_leave_one_ramp_out_splits, 'full',
    lambda i, test_X: f"Full Fold {i}",
)

print(f"\nFull data quick test MAE: {np.mean(errors_full):.4f}")

Testing GNNModel...


Single Fold 0 (1,1,1,3,3,3-Hexafluoropropan-2-ol): MAE = 0.2505


Single Fold 1 (2,2,2-Trifluoroethanol): MAE = 0.1118


Single Fold 2 (2-Methyltetrahydrofuran [2-MeTHF]): MAE = 0.0502

Single solvent quick test MAE: 0.1375

Testing on full data...


Full Fold 0: MAE = 0.0840


Full Fold 1: MAE = 0.1285


Full Fold 2: MAE = 0.0833

Full data quick test MAE: 0.0986


In [8]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModel(data='single') # CHANGE THIS LINE ONLY (single-solvent mode: one graph per 'SOLVENT NAME')
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

0it [00:00, ?it/s]

1it [00:06,  6.33s/it]

2it [00:12,  6.34s/it]

3it [00:18,  6.22s/it]

4it [00:24,  6.19s/it]

5it [00:31,  6.32s/it]

6it [00:37,  6.36s/it]

7it [00:44,  6.37s/it]

8it [00:50,  6.36s/it]

9it [00:57,  6.38s/it]

10it [01:03,  6.37s/it]

11it [01:09,  6.37s/it]

12it [01:16,  6.37s/it]

13it [01:22,  6.37s/it]

14it [01:28,  6.39s/it]

15it [01:35,  6.40s/it]

16it [01:41,  6.44s/it]

17it [01:48,  6.59s/it]

18it [01:55,  6.58s/it]

19it [02:01,  6.56s/it]

20it [02:08,  6.56s/it]

21it [02:14,  6.54s/it]

22it [02:21,  6.54s/it]

23it [02:27,  6.52s/it]

24it [02:34,  6.52s/it]

24it [02:34,  6.44s/it]




In [9]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModel(data = 'full') # CHANGE THIS LINE ONLY (mixed-solvent mode: graph of 'SOLVENT A NAME' plus SolventB% as a condition)
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

0it [00:00, ?it/s]

1it [00:11, 11.37s/it]

2it [00:22, 11.31s/it]

3it [00:34, 11.39s/it]

4it [00:45, 11.32s/it]

5it [00:56, 11.25s/it]

6it [01:07, 11.22s/it]

7it [01:18, 11.20s/it]

8it [01:30, 11.23s/it]

9it [01:41, 11.23s/it]

10it [01:53, 11.53s/it]

11it [02:05, 11.73s/it]

12it [02:17, 11.86s/it]

13it [02:29, 11.94s/it]

13it [02:29, 11.53s/it]




In [10]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################