# Experiment 086: Hybrid GNN (GAT + DRFP + Mixture-Aware)

**Based on arXiv:2512.19530**: The benchmark paper achieved MSE 0.0039 using a hybrid GNN architecture.

**Architecture**:
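1. **Graph Attention Network (GAT)** - message passing over solvent molecular graphs built from SMILES
2. **DRFP encoder** - differential reaction fingerprints (pre-computed, looked up per solvent)
3. **Mixture-aware encoding** - combines the solvent A and solvent B embeddings with the mixing ratio (`SolventB%`) for continuous mixtures
4. **Kinetic features** - Arrhenius-style temperature/time features (1000/T, log time, and their product)
5. **Spange encoder** - pre-computed physicochemical solvent descriptors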

**Hypothesis**: our previous GNN attempts (CV 0.024-0.026) likely underperformed because they lacked this hybrid architecture, in particular the DRFP and mixture-aware encoding.

**CRITICAL**: The SAME model class (with the same hyperparameters) must be used in both the CV computation AND the submission cells!

In [None]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: disable cuDNN autotuning, request deterministic cuDNN
# kernels, and seed the NumPy and torch global RNGs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(42)
torch.manual_seed(42)

# Run on the default CUDA device when one is visible, otherwise on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f'Using device: {device}')
print(f'GPU: {torch.cuda.get_device_name(0) if _cuda_ok else "N/A"}')

In [None]:
# PyTorch Geometric imports
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import Data, Batch
from rdkit import Chem
from rdkit.Chem import AllChem

print('PyTorch Geometric loaded successfully')

In [None]:
# Data loading functions
DATA_PATH = '/home/data'

# Input columns per task: the shared numeric conditions, plus the solvent
# identity (single task) or the solvent pair and mixing ratio (full task).
INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
INPUT_LABELS_SINGLE_SOLVENT = INPUT_LABELS_NUMERIC + ["SOLVENT NAME"]
INPUT_LABELS_FULL_SOLVENT = INPUT_LABELS_NUMERIC + [
    "SOLVENT A NAME",
    "SOLVENT B NAME",
    "SolventB%",
]

def load_data(name="full"):
    """Load inputs X and targets Y for one of the two benchmark tasks.

    Args:
        name: "full" for the solvent-mixture data
            (catechol_full_data_yields.csv) or "single_solvent" for the
            single-solvent data (catechol_single_solvent_yields.csv).

    Returns:
        (X, Y): X holds the task's input columns; Y holds the
        "Product 2", "Product 3" and "SM" target columns.

    Raises:
        ValueError: if ``name`` is not one of the two supported tasks, so a
            typo cannot silently load the wrong file.
    """
    if name == "full":
        df = pd.read_csv(f'{DATA_PATH}/catechol_full_data_yields.csv')
        X = df[INPUT_LABELS_FULL_SOLVENT]
    elif name == "single_solvent":
        df = pd.read_csv(f'{DATA_PATH}/catechol_single_solvent_yields.csv')
        X = df[INPUT_LABELS_SINGLE_SOLVENT]
    else:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected 'full' or 'single_solvent'"
        )
    Y = df[["Product 2", "Product 3", "SM"]]
    return X, Y

def generate_leave_one_out_splits(X, Y):
    """Yield one train/test split per solvent, holding that solvent out.

    Solvents are visited in sorted order of ``X["SOLVENT NAME"]``. Each item
    is ``((X_train, Y_train), (X_test, Y_test))``; the test part contains
    every row of the held-out solvent and the train part all other rows.
    """
    solvent_col = X["SOLVENT NAME"]
    for held_out in sorted(solvent_col.unique()):
        is_test = solvent_col == held_out
        yield (X[~is_test], Y[~is_test]), (X[is_test], Y[is_test])

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield one train/test split per (solvent A, solvent B) pair, holding it out.

    Pairs are visited in order of first appearance in ``X``. The pair is
    ordered, so (A, B) and (B, A) are distinct ramps. Each item is
    ``((X_train, Y_train), (X_test, Y_test))``.
    """
    pairs = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
    for solvent_a, solvent_b in pairs.itertuples(index=False, name=None):
        in_ramp = (X["SOLVENT A NAME"] == solvent_a) & (X["SOLVENT B NAME"] == solvent_b)
        yield (X[~in_ramp], Y[~in_ramp]), (X[in_ramp], Y[in_ramp])

print('Data loading functions defined')

In [None]:
# Solvent name -> SMILES string.
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv')
SMILES_DICT = dict(zip(SMILES_DF['SOLVENT NAME'], SMILES_DF['solvent smiles']))

# DRFP lookup, indexed by its first CSV column (solvent name). Columns whose
# variance is not strictly positive (constant columns) are dropped.
DRFP_DF = pd.read_csv(f'{DATA_PATH}/drfps_catechol_lookup.csv', index_col=0)
drfp_variance = DRFP_DF.var()
nonzero_variance_cols = list(drfp_variance.index[drfp_variance > 0])
DRFP_FILTERED = DRFP_DF[nonzero_variance_cols]

# Spange physicochemical descriptor lookup, indexed by its first CSV column.
SPANGE_DF = pd.read_csv(f'{DATA_PATH}/spange_descriptors_lookup.csv', index_col=0)

n_smiles = len(SMILES_DICT)
print(f'SMILES: {n_smiles} solvents')
print(f'DRFP filtered: {DRFP_FILTERED.shape}')
print(f'Spange: {SPANGE_DF.shape}')

In [None]:
# Convert SMILES to molecular graph
def _empty_graph(num_node_features=9):
    """Return a placeholder graph: one all-zero node, no edges, 3-wide edge_attr."""
    return Data(
        x=torch.zeros(1, num_node_features, dtype=torch.float),
        edge_index=torch.zeros(2, 0, dtype=torch.long),
        edge_attr=torch.zeros(0, 3, dtype=torch.float),
    )


def smiles_to_graph(smiles):
    """Convert a SMILES string to a PyTorch Geometric ``Data`` graph.

    Dot-separated SMILES (e.g. ``'O.CC#N'``) are parsed as one RDKit molecule
    with disconnected fragments, so every component contributes atoms. Bonds,
    and therefore message passing, stay within each fragment. If the full
    string fails to parse, the first component alone is tried.

    Node features (9 per atom): atomic number, degree, formal charge,
    hybridization (as int), aromatic flag, total H count, radical electrons,
    in-ring flag, atomic mass / 100.

    Edge features (3 per directed edge): bond order, conjugated flag, in-ring
    flag. Every bond is stored in both directions.

    Returns a single zero-feature node with no edges when ``smiles`` is not a
    string, cannot be parsed, or yields no atoms.
    """
    if not isinstance(smiles, str):
        # e.g. a missing lookup entry read from CSV as NaN
        return _empty_graph()

    mol = Chem.MolFromSmiles(smiles)
    if mol is None and '.' in smiles:
        mol = Chem.MolFromSmiles(smiles.split('.')[0])
    if mol is None or mol.GetNumAtoms() == 0:
        # A zero-atom molecule would otherwise give a 1-D, featureless x tensor.
        return _empty_graph()

    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.IsInRing()),
            atom.GetMass() / 100.0,  # normalized mass
        ])
    x = torch.tensor(atom_features, dtype=torch.float)

    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_features = [
            float(bond.GetBondTypeAsDouble()),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]
        # Store each undirected bond as two directed edges.
        edge_index.extend([[i, j], [j, i]])
        edge_attr.extend([bond_features, bond_features])

    if not edge_index:
        # Molecules with no bonds (a single heavy atom) get empty edge tensors.
        return Data(
            x=x,
            edge_index=torch.zeros(2, 0, dtype=torch.long),
            edge_attr=torch.zeros(0, 3, dtype=torch.float),
        )
    return Data(
        x=x,
        edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous(),
        edge_attr=torch.tensor(edge_attr, dtype=torch.float),
    )

# Build every solvent's graph once, keyed by solvent name, so feature
# preparation can look graphs up instead of re-parsing SMILES.
SOLVENT_GRAPHS = {
    solvent_name: smiles_to_graph(solvent_smiles)
    for solvent_name, solvent_smiles in SMILES_DICT.items()
}

print(f'Pre-computed graphs for {len(SOLVENT_GRAPHS)} solvents')

In [None]:
# Hybrid GNN Model
class HybridGNNModel(nn.Module):
    """Hybrid solvent model: GAT graph embedding + DRFP + Spange + kinetic features.

    Each input stream is encoded to ``hidden_dim`` features. The four
    embeddings (graph, DRFP, Spange, kinetic) are concatenated and mapped by an
    MLP head to 3 outputs (P2, P3, SM).

    Mixture inputs: solvent B goes through the same shared encoders. The two
    graph embeddings plus the ratio pass through ``mixture_encoder``. The DRFP
    and Spange embeddings are linearly interpolated with weight
    ``mixture_ratio`` on solvent B.
    NOTE(review): the interpolation assumes ``mixture_ratio`` is a fraction in
    [0, 1]; confirm the units of the column it is fed from.

    Args:
        node_dim: per-atom feature count of the input graphs.
        drfp_dim: DRFP fingerprint width.
        spange_dim: Spange descriptor width.
        hidden_dim: width of each per-stream embedding.
        num_heads: attention heads in the first GAT layer.
    """

    def __init__(self, node_dim=9, drfp_dim=122, spange_dim=13, hidden_dim=64, num_heads=4):
        super().__init__()

        # Submodules are created in exactly this order; with a fixed torch
        # seed the initial weights depend on it.

        # Two GAT layers: multi-head with concatenated outputs, then one head.
        self.gat1 = GATConv(node_dim, hidden_dim, heads=num_heads, concat=True)
        self.gat2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1, concat=False)

        self.drfp_encoder = self._two_layer_mlp(drfp_dim, hidden_dim)
        self.spange_encoder = self._two_layer_mlp(spange_dim, hidden_dim)
        # 5 kinetic inputs: time, temp, inv_temp, log_time, interaction.
        self.kinetic_encoder = self._two_layer_mlp(5, hidden_dim)
        # Input: graph embeddings of solvents A and B plus the mixing ratio.
        self.mixture_encoder = self._two_layer_mlp(2 * hidden_dim + 1, hidden_dim)

        # Head over the concatenated [graph, DRFP, Spange, kinetic] embeddings.
        self.predictor = nn.Sequential(
            nn.Linear(4 * hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 3),  # P2, P3, SM
        )

    @staticmethod
    def _two_layer_mlp(in_dim, out_dim):
        """Linear(in_dim, out_dim) -> ReLU -> Linear(out_dim, out_dim)."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def encode_solvent(self, graph_batch):
        """Run both GAT layers and mean-pool node embeddings per graph."""
        hidden = F.relu(self.gat1(graph_batch.x, graph_batch.edge_index))
        hidden = self.gat2(hidden, graph_batch.edge_index)
        return global_mean_pool(hidden, graph_batch.batch)

    def forward(self, graph_batch, drfp, spange, kinetic, mixture_ratio=None, graph_batch_b=None, drfp_b=None, spange_b=None):
        """Predict the 3 targets.

        The mixture path is taken only when both ``mixture_ratio`` and
        ``graph_batch_b`` are given; ``drfp_b``/``spange_b`` are then required.
        """
        graph_a = self.encode_solvent(graph_batch)
        drfp_a = self.drfp_encoder(drfp)
        spange_a = self.spange_encoder(spange)

        if mixture_ratio is None or graph_batch_b is None:
            solvent_parts = [graph_a, drfp_a, spange_a]
        else:
            weight_b = mixture_ratio if mixture_ratio.dim() != 1 else mixture_ratio.unsqueeze(1)
            graph_b = self.encode_solvent(graph_batch_b)
            mixed_graph = self.mixture_encoder(torch.cat([graph_a, graph_b, weight_b], dim=1))
            mixed_drfp = (1 - weight_b) * drfp_a + weight_b * self.drfp_encoder(drfp_b)
            mixed_spange = (1 - weight_b) * spange_a + weight_b * self.spange_encoder(spange_b)
            solvent_parts = [mixed_graph, mixed_drfp, mixed_spange]

        features = torch.cat(solvent_parts + [self.kinetic_encoder(kinetic)], dim=1)
        return self.predictor(features)

print('HybridGNNModel defined')

In [None]:
# Wrapper class for training and prediction (compatible with submission template)
class HybridGNNWrapper:
    """Train/predict wrapper around ``HybridGNNModel`` used by the CV and submission cells.

    Args:
        data: ``'full'`` selects the two-solvent (mixture) feature path; any
            other value uses the single-solvent path.
        hidden_dim: embedding width passed to ``HybridGNNModel``.
        num_epochs: number of passes over the training rows.
        lr: Adam learning rate.

    All numeric inputs (kinetic, DRFP, Spange) and the targets are
    standardised with scalers fitted on the training rows only. ``predict``
    returns values in the original target units as a float64 tensor.
    """

    def __init__(self, data='single', hidden_dim=64, num_epochs=100, lr=0.001):
        self.data = data
        self.mixed = (data == 'full')
        self.hidden_dim = hidden_dim
        self.num_epochs = num_epochs
        self.lr = lr
        self.model = None
        self.scaler = StandardScaler()  # targets (Y)
        self.drfp_scaler = StandardScaler()
        self.spange_scaler = StandardScaler()
        # Kinetic inputs are on very different scales (e.g. raw temperature
        # vs 1000/T), so they are standardised like the other inputs.
        self.kinetic_scaler = StandardScaler()

    def _prepare_features(self, X, Y=None, fit_scalers=False):
        """Build scaled numpy features and graph lists for the rows of ``X``.

        Args:
            X: input frame with the single-solvent or full-data columns,
                matching ``self.mixed``.
            Y: unused; kept for backward compatibility.
            fit_scalers: fit the input scalers on these rows before
                transforming (training data only).

        Returns:
            dict keyed by feature name. Mixture data uses ``_a``/``_b``
            variants plus ``'mixture_ratio'``.
        """
        # Kinetic features: time, temperature, inverse absolute temperature
        # (Arrhenius-style), log time and their interaction.
        time_m = X["Residence Time"].values.reshape(-1, 1)
        temp_c = X["Temperature"].values.reshape(-1, 1)
        inv_temp = 1000.0 / (temp_c + 273.15)
        log_time = np.log(time_m + 1e-6)  # epsilon guards against log(0)
        kinetic = np.hstack([time_m, temp_c, inv_temp, log_time, inv_temp * log_time])
        if fit_scalers:
            self.kinetic_scaler.fit(kinetic)
        kinetic = self.kinetic_scaler.transform(kinetic)

        if self.mixed:
            solvent_a = X["SOLVENT A NAME"].values
            solvent_b = X["SOLVENT B NAME"].values
            # NOTE(review): used directly as the interpolation weight on
            # solvent B; assumes a fraction in [0, 1] - confirm units.
            mixture_ratio = X["SolventB%"].values

            drfp_a = DRFP_FILTERED.loc[solvent_a].values
            drfp_b = DRFP_FILTERED.loc[solvent_b].values
            spange_a = SPANGE_DF.loc[solvent_a].values
            spange_b = SPANGE_DF.loc[solvent_b].values

            if fit_scalers:
                # One scaler per feature type, shared by both solvent slots.
                self.drfp_scaler.fit(np.vstack([drfp_a, drfp_b]))
                self.spange_scaler.fit(np.vstack([spange_a, spange_b]))

            return {
                'kinetic': kinetic,
                'drfp_a': self.drfp_scaler.transform(drfp_a),
                'drfp_b': self.drfp_scaler.transform(drfp_b),
                'spange_a': self.spange_scaler.transform(spange_a),
                'spange_b': self.spange_scaler.transform(spange_b),
                'graphs_a': [SOLVENT_GRAPHS[s] for s in solvent_a],
                'graphs_b': [SOLVENT_GRAPHS[s] for s in solvent_b],
                'mixture_ratio': mixture_ratio,
            }

        solvent = X["SOLVENT NAME"].values
        drfp = DRFP_FILTERED.loc[solvent].values
        spange = SPANGE_DF.loc[solvent].values
        if fit_scalers:
            self.drfp_scaler.fit(drfp)
            self.spange_scaler.fit(spange)
        return {
            'kinetic': kinetic,
            'drfp': self.drfp_scaler.transform(drfp),
            'spange': self.spange_scaler.transform(spange),
            'graphs': [SOLVENT_GRAPHS[s] for s in solvent],
        }

    def _model_inputs(self, features, idx):
        """Return the positional arguments of ``HybridGNNModel.forward`` for rows ``idx``."""
        def as_tensor(key):
            return torch.tensor(features[key][idx], dtype=torch.float32).to(device)

        def as_graph_batch(key):
            return Batch.from_data_list([features[key][j] for j in idx]).to(device)

        kinetic = as_tensor('kinetic')
        if self.mixed:
            return (
                as_graph_batch('graphs_a'), as_tensor('drfp_a'), as_tensor('spange_a'), kinetic,
                as_tensor('mixture_ratio'),
                as_graph_batch('graphs_b'), as_tensor('drfp_b'), as_tensor('spange_b'),
            )
        return (as_graph_batch('graphs'), as_tensor('drfp'), as_tensor('spange'), kinetic)

    def train_model(self, X, Y):
        """Fit scalers and train a fresh ``HybridGNNModel`` on (X, Y) with mini-batch Adam.

        Raises:
            ValueError: if ``X`` has no rows.
        """
        if len(X) == 0:
            raise ValueError("train_model requires at least one training row")

        features = self._prepare_features(X, Y, fit_scalers=True)
        Y_vals = Y.values
        self.scaler.fit(Y_vals)
        Y_scaled = self.scaler.transform(Y_vals)

        self.model = HybridGNNModel(
            node_dim=9,
            drfp_dim=DRFP_FILTERED.shape[1],
            spange_dim=SPANGE_DF.shape[1],
            hidden_dim=self.hidden_dim
        ).to(device)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.MSELoss()

        self.model.train()
        n_samples = len(Y_vals)
        batch_size = min(32, n_samples)

        for epoch in range(self.num_epochs):
            indices = np.random.permutation(n_samples)
            total_loss = 0.0

            for start in range(0, n_samples, batch_size):
                batch_idx = indices[start:start + batch_size]
                y_batch = torch.tensor(Y_scaled[batch_idx], dtype=torch.float32).to(device)

                pred = self.model(*self._model_inputs(features, batch_idx))
                loss = criterion(pred, y_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * len(batch_idx)

            if (epoch + 1) % 50 == 0:
                print(f'Epoch {epoch+1}/{self.num_epochs}, Loss: {total_loss/n_samples:.6f}')

    def predict(self, X):
        """Predict targets for ``X`` in original units as a float64 tensor of shape [N, n_targets].

        Raises:
            RuntimeError: if called before ``train_model``.
        """
        if self.model is None:
            raise RuntimeError("predict() called before train_model()")
        if len(X) == 0:
            return torch.zeros((0, self.scaler.mean_.shape[0]), dtype=torch.float64)

        features = self._prepare_features(X, fit_scalers=False)
        self.model.eval()
        with torch.no_grad():
            pred = self.model(*self._model_inputs(features, np.arange(len(X))))
        pred_unscaled = self.scaler.inverse_transform(pred.cpu().numpy())
        return torch.tensor(pred_unscaled, dtype=torch.float64)

print('HybridGNNWrapper defined')

In [None]:
# Quick smoke test: train briefly on 10 rows and check the predictions have the
# expected shape and contain no NaN/inf before launching the full CV.
print('Testing model on a small batch...')
X, Y = load_data("single_solvent")
X_small = X.head(10)
Y_small = Y.head(10)

model = HybridGNNWrapper(data='single', num_epochs=10)
model.train_model(X_small, Y_small)
preds = model.predict(X_small)
print(f'Predictions shape: {preds.shape}')
print(f'Sample predictions: {preds[:3]}')
assert tuple(preds.shape) == (len(X_small), Y_small.shape[1]), f'Unexpected prediction shape {tuple(preds.shape)}'
assert bool(torch.isfinite(preds).all()), 'Predictions contain NaN or inf'

In [None]:
# Cross-validation on single solvent data (leave one solvent out).
# Reported score is the mean of the per-fold MSEs.
import tqdm

X, Y = load_data("single_solvent")
# One fold per distinct solvent; computed rather than hard-coded, and passed
# to tqdm so the progress bar knows its length (enumerate() hides it).
n_folds = len(X["SOLVENT NAME"].unique())
print(f'Single solvent data: {len(X)} samples, {n_folds} solvents')

all_mse = []
for fold_idx, split in tqdm.tqdm(enumerate(generate_leave_one_out_splits(X, Y)), total=n_folds):
    (train_X, train_Y), (test_X, test_Y) = split

    model = HybridGNNWrapper(data='single', hidden_dim=64, num_epochs=100, lr=0.001)
    model.train_model(train_X, train_Y)

    preds = model.predict(test_X).numpy()
    targets = test_Y.values

    mse = np.mean((preds - targets) ** 2)
    all_mse.append(mse)

    if (fold_idx + 1) % 6 == 0:
        print(f'Fold {fold_idx+1}/{n_folds}, MSE: {mse:.6f}, Running avg: {np.mean(all_mse):.6f}')

single_mse = np.mean(all_mse)
print(f'\nSingle Solvent MSE: {single_mse:.6f}')

In [None]:
# Cross-validation on full data (leave one solvent ramp out).
# Reported score is the mean of the per-fold MSEs.
import tqdm  # also imported by the single-solvent CV cell; repeated so this cell runs on its own

X_full, Y_full = load_data("full")
print(f'Full data: {len(X_full)} samples')
# One fold per distinct (solvent A, solvent B) pair; passed to tqdm so the
# progress bar knows its length.
n_ramps = len(X_full[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates())

all_mse_full = []
for fold_idx, split in tqdm.tqdm(enumerate(generate_leave_one_ramp_out_splits(X_full, Y_full)), total=n_ramps):
    (train_X, train_Y), (test_X, test_Y) = split

    model = HybridGNNWrapper(data='full', hidden_dim=64, num_epochs=100, lr=0.001)
    model.train_model(train_X, train_Y)

    preds = model.predict(test_X).numpy()
    targets = test_Y.values

    mse = np.mean((preds - targets) ** 2)
    all_mse_full.append(mse)

full_mse = np.mean(all_mse_full)
print(f'\nFull Data MSE: {full_mse:.6f}')

In [None]:
# Calculate overall CV score: the two task scores (each a mean of per-fold
# MSEs) averaged with weights equal to each task's number of samples. Sample
# counts are read from the data instead of being hard-coded.
n_single = len(load_data("single_solvent")[0])
n_full = len(load_data("full")[0])
overall_mse = (single_mse * n_single + full_mse * n_full) / (n_single + n_full)

best_tabular_mse = 0.008298  # best previous tabular-model CV score
benchmark_target_mse = 0.0039  # score reported by arXiv:2512.19530

print(f'\n=== Hybrid GNN Results ===')
print(f'Single Solvent MSE: {single_mse:.6f}')
print(f'Full Data MSE: {full_mse:.6f}')
print(f'Overall MSE: {overall_mse:.6f}')
print(f'\nBest tabular baseline: {best_tabular_mse}')
print(f'Difference: {(overall_mse - best_tabular_mse) / best_tabular_mse * 100:.2f}%')
print(f'\nBenchmark paper target: {benchmark_target_mse}')

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = HybridGNNWrapper(data='single', hidden_dim=64, num_epochs=100, lr=0.001)  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = HybridGNNWrapper(data='full', hidden_dim=64, num_epochs=100, lr=0.001)  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################