# Proper GNN Implementation with PyTorch Geometric

**Hypothesis**: The benchmark paper achieved MSE 0.0039 with GNN. Our previous GNN attempts achieved CV 0.025+ (3x worse than tabular). This suggests implementation issues.

**Key differences from previous attempts**:
1. Use PyTorch Geometric's `from_smiles` utility for proper graph construction
2. Use GCNConv layers with proper message passing
3. Use global_mean_pool for graph-level readout
4. Combine with process features (T, RT, kinetics)
5. VERIFY submission cells use the SAME model class as CV

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import tqdm
import warnings
warnings.filterwarnings('ignore')

# Single seed shared by NumPy and PyTorch so a fresh kernel reproduces the run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when one is visible, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')

Using device: cuda


In [2]:
# PyTorch Geometric imports
from torch_geometric.utils import from_smiles
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader

print('PyTorch Geometric imports successful')

PyTorch Geometric imports successful


In [3]:
# Data loading functions
DATA_PATH = '/home/data'

INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
INPUT_LABELS_SINGLE_SOLVENT = ["Residence Time", "Temperature", "SOLVENT NAME"]
INPUT_LABELS_FULL_SOLVENT = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]

def load_data(name="full"):
    """Load one of the two yield datasets and split it into inputs and targets.

    Parameters
    ----------
    name : str
        "full" loads `catechol_full_data_yields.csv` with the mixed-solvent
        input columns; "single_solvent" (or the alias "single") loads
        `catechol_single_solvent_yields.csv` with the single-solvent columns.

    Returns
    -------
    (X, Y) : tuple of pandas.DataFrame
        X holds the input columns for the chosen dataset, Y the three target
        columns "Product 2", "Product 3" and "SM".

    Raises
    ------
    ValueError
        If `name` is not a recognised dataset name. Previously any unknown
        value (e.g. a typo such as "Full") silently loaded the single-solvent
        file.
    """
    # Validate before touching the filesystem so a typo fails fast and loudly.
    if name == "full":
        filename = "catechol_full_data_yields.csv"
    elif name in ("single_solvent", "single"):
        filename = "catechol_single_solvent_yields.csv"
    else:
        raise ValueError(
            f"Unknown dataset name {name!r}; expected 'full' or 'single_solvent'"
        )

    input_labels = INPUT_LABELS_FULL_SOLVENT if name == "full" else INPUT_LABELS_SINGLE_SOLVENT
    df = pd.read_csv(f'{DATA_PATH}/{filename}')
    X = df[input_labels]
    Y = df[["Product 2", "Product 3", "SM"]]
    return X, Y

def generate_leave_one_out_splits(X, Y):
    """Yield leave-one-solvent-out folds.

    Solvents are visited in sorted order of the distinct values in the
    "SOLVENT NAME" column of X. Each fold is
    ``((train_X, train_Y), (test_X, test_Y))`` where the test part holds
    exactly the rows of the held-out solvent and the train part every other
    row. Y is indexed with the same boolean mask as X.
    """
    solvent_col = X["SOLVENT NAME"]
    for held_out in sorted(solvent_col.unique()):
        is_test = solvent_col == held_out
        is_train = ~is_test
        train_part = (X[is_train], Y[is_train])
        test_part = (X[is_test], Y[is_test])
        yield train_part, test_part

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield leave-one-ramp-out folds.

    A "ramp" is one distinct ("SOLVENT A NAME", "SOLVENT B NAME") pair, visited
    in order of first appearance in X. Each fold is
    ``((train_X, train_Y), (test_X, test_Y))`` where the test part holds the
    rows of the held-out pair and the train part all remaining rows.
    """
    col_a, col_b = "SOLVENT A NAME", "SOLVENT B NAME"
    unique_pairs = X[[col_a, col_b]].drop_duplicates()
    for solvent_a, solvent_b in unique_pairs.itertuples(index=False, name=None):
        in_ramp = (X[col_a] == solvent_a) & (X[col_b] == solvent_b)
        out_of_ramp = ~in_ramp
        yield (X[out_of_ramp], Y[out_of_ramp]), (X[in_ramp], Y[in_ramp])

print('Data loading functions defined')

Data loading functions defined


In [4]:
# Load SMILES lookup
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
print(f'SMILES lookup shape: {SMILES_DF.shape}')
print(SMILES_DF.head())

SMILES lookup shape: (26, 1)
                                           solvent smiles
SOLVENT NAME                                             
Cyclohexane                                      C1CCCCC1
Ethyl Acetate                                   O=C(OCC)C
Acetic Acid                                       CC(=O)O
2-Methyltetrahydrofuran [2-MeTHF]              O1C(C)CCC1
1,1,1,3,3,3-Hexafluoropropan-2-ol  C(C(F)(F)F)(C(F)(F)F)O


In [5]:
# Test from_smiles on a sample SMILES
test_smiles = 'CCO'  # Ethanol
test_graph = from_smiles(test_smiles)
print(f'Test graph for {test_smiles}:')
print(f'  x (atom features): {test_graph.x.shape}')
print(f'  edge_index: {test_graph.edge_index.shape}')
print(f'  edge_attr: {test_graph.edge_attr.shape if test_graph.edge_attr is not None else None}')
print(f'  Atom features (first atom): {test_graph.x[0]}')

Test graph for CCO:
  x (atom features): torch.Size([3, 9])
  edge_index: torch.Size([2, 4])
  edge_attr: torch.Size([4, 3])
  Atom features (first atom): tensor([6, 0, 4, 5, 3, 0, 4, 0, 0])


In [6]:
# Pre-compute molecular graphs for all solvents
print('Pre-computing molecular graphs for all solvents...')
SOLVENT_GRAPHS = {}
# Solvents whose SMILES could not be turned into a usable graph, mapped to the
# reason. Kept so the failure is inspectable instead of only printed.
FAILED_SOLVENTS = {}
for solvent_name, row in SMILES_DF.iterrows():
    smiles = row['solvent smiles']
    try:
        graph = from_smiles(smiles)
        # NOTE(review): from_smiles appears to return an atom-less graph rather
        # than raising for an unparseable SMILES (confirm against the installed
        # PyG version). Such a graph would break pooling later, so reject it here.
        if graph.x.shape[0] == 0:
            raise ValueError(f'no atoms parsed from SMILES {smiles!r}')
    except Exception as e:
        # Best effort: keep going so one bad entry does not block the others,
        # but remember the failure.
        FAILED_SOLVENTS[solvent_name] = str(e)
        print(f'  ERROR for {solvent_name}: {e}')
        continue
    SOLVENT_GRAPHS[solvent_name] = graph
    print(f'  {solvent_name}: {smiles[:30]}... -> {graph.x.shape[0]} atoms, {graph.edge_index.shape[1]} edges')

print(f'\nTotal solvents with graphs: {len(SOLVENT_GRAPHS)}')
if FAILED_SOLVENTS:
    print(f'Solvents WITHOUT graphs ({len(FAILED_SOLVENTS)}): {sorted(FAILED_SOLVENTS)}')

Pre-computing molecular graphs for all solvents...
  Cyclohexane: C1CCCCC1... -> 6 atoms, 12 edges
  Ethyl Acetate: O=C(OCC)C... -> 6 atoms, 10 edges
  Acetic Acid: CC(=O)O... -> 4 atoms, 6 edges
  2-Methyltetrahydrofuran [2-MeTHF]: O1C(C)CCC1... -> 6 atoms, 12 edges
  1,1,1,3,3,3-Hexafluoropropan-2-ol: C(C(F)(F)F)(C(F)(F)F)O... -> 10 atoms, 18 edges
  IPA [Propan-2-ol]: CC(O)C... -> 4 atoms, 6 edges
  Ethanol: CCO... -> 3 atoms, 4 edges
  Methanol: CO... -> 2 atoms, 2 edges
  Ethylene Glycol [1,2-Ethanediol]: OCCO... -> 4 atoms, 6 edges
  Acetonitrile: CC#N... -> 3 atoms, 4 edges
  Water: O... -> 1 atoms, 0 edges
  Diethyl Ether [Ether]: CCOCC... -> 5 atoms, 8 edges
  MTBE [tert-Butylmethylether]: CC(C)(C)OC... -> 6 atoms, 10 edges
  Dimethyl Carbonate: COC(=O)OC... -> 6 atoms, 10 edges
  tert-Butanol [2-Methylpropan-2-ol]: CC(C)(C)O... -> 5 atoms, 8 edges
  DMA [N,N-Dimethylacetamide]: CN(C)C(C)=O... -> 6 atoms, 10 edges
  2,2,2-Trifluoroethanol: OCC(F)(F)F... -> 6 atoms, 10 edges
  

In [None]:
# GNN Model
class MolGNN(nn.Module):
    """Three-layer GCN over a molecular graph, fused with process features.

    Node features pass through three GCNConv -> BatchNorm -> ReLU stages
    (dropout after the first two), are mean-pooled per graph, concatenated
    with five per-sample process features supplied by the caller, and mapped
    by a two-layer MLP to `out_dim` sigmoid outputs.

    Parameters
    ----------
    in_channels : int
        Number of features per node.
    hidden_dim : int
        Width of every GCN layer and of the hidden MLP layer.
    out_dim : int
        Number of predicted targets.
    """

    def __init__(self, in_channels=9, hidden_dim=64, out_dim=3):
        super(MolGNN, self).__init__()
        # Layers are created in a fixed order (convs, norms, head) so that a
        # seeded run draws the same initial weights.
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim)
        # The head sees the pooled graph embedding plus 5 process features.
        self.fc1 = nn.Linear(hidden_dim + 5, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, data, process_features):
        """Return per-graph predictions in (0, 1).

        `data` must expose `x`, `edge_index` and `batch`; `process_features`
        is concatenated to the pooled embedding along dim 1, so it needs one
        row per graph and 5 columns to match `fc1`.
        """
        node_h = data.x.float()
        edges = data.edge_index
        graph_ids = data.batch

        # (conv, norm, apply dropout afterwards?) - the last stage skips dropout.
        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
        )
        for conv, norm, use_dropout in stages:
            node_h = F.relu(norm(conv(node_h, edges)))
            if use_dropout:
                node_h = self.dropout(node_h)

        # Collapse node embeddings to one vector per graph.
        graph_h = global_mean_pool(node_h, graph_ids)

        # Fuse structure and process information, then regress.
        joint = torch.cat((graph_h, process_features), dim=1)
        hidden = self.dropout(F.relu(self.fc1(joint)))
        # Sigmoid keeps outputs in the unit interval, matching yield fractions.
        return torch.sigmoid(self.fc2(hidden))

print('MolGNN defined')

In [None]:
# GNN Model Wrapper for single solvent data
class GNNModelWrapper:
    """Train/predict wrapper that feeds solvent graphs and process features to `MolGNN`.

    Parameters
    ----------
    data : str
        'single' reads the solvent from the 'SOLVENT NAME' column. Any other
        value selects the mixed-solvent layout ('SOLVENT A NAME',
        'SOLVENT B NAME', 'SolventB%').
    hidden_dim : int
        Hidden width passed to `MolGNN`.
    lr : float
        Adam learning rate.
    epochs : int
        Number of passes over the training data.
    batch_size : int
        Number of graphs per mini-batch.
    """

    # Width of the per-sample process feature vector; `MolGNN` hard-codes the
    # same value in its first fully connected layer.
    N_PROCESS_FEATS = 5

    def __init__(self, data='single', hidden_dim=64, lr=0.001, epochs=100, batch_size=32):
        self.data = data
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.model = None
        # Per-target mean of the training labels; used as the fallback
        # prediction for rows that cannot be turned into a graph.
        self.train_mean = None

    def _build_graphs(self, X, Y=None):
        """Build one graph per usable row of X.

        Returns
        -------
        (data_list, kept_rows)
            `data_list` holds the graphs, `kept_rows` the positional row
            indices of X they were built from. Rows whose solvent(s) have no
            entry in SOLVENT_GRAPHS are skipped, so `kept_rows` is what lets
            callers map outputs back onto the original rows.
        """
        data_list = []
        kept_rows = []
        is_single = self.data == 'single'

        for idx in range(len(X)):
            row = X.iloc[idx]

            if is_single:
                solvent_name = row['SOLVENT NAME']
                if solvent_name not in SOLVENT_GRAPHS:
                    continue
                graph = SOLVENT_GRAPHS[solvent_name].clone()
            else:
                solvent_a = row['SOLVENT A NAME']
                solvent_b = row['SOLVENT B NAME']
                if solvent_a not in SOLVENT_GRAPHS or solvent_b not in SOLVENT_GRAPHS:
                    continue
                # NOTE(review): only solvent A's graph is used. Solvent B's
                # structure never reaches the model; the mixture is described
                # solely by 'SolventB%'. A real two-graph readout would need a
                # matching change in MolGNN.forward.
                graph = SOLVENT_GRAPHS[solvent_a].clone()

            T = row['Temperature']
            RT = row['Residence Time']
            # presumably T is in degrees Celsius given the 273.15 offset - verify
            inv_T = 1000.0 / (T + 273.15)
            # Small epsilon keeps the log finite for a zero residence time.
            ln_RT = np.log(RT + 1e-6)
            # The fifth feature differs by mode: an inv_T * ln_RT interaction
            # term for single solvents, the solvent B percentage for mixtures.
            fifth = inv_T * ln_RT if is_single else row['SolventB%']

            graph.process_feats = torch.tensor([T, RT, inv_T, ln_RT, fifth], dtype=torch.float)

            if Y is not None:
                graph.y = torch.tensor(Y.iloc[idx].values, dtype=torch.float)

            data_list.append(graph)
            kept_rows.append(idx)

        return data_list, kept_rows

    def _prepare_data(self, X, Y=None):
        """Prepare data for GNN: return the list of graphs built from X (and Y)."""
        return self._build_graphs(X, Y)[0]

    def train_model(self, X, Y):
        """Train a fresh `MolGNN` on (X, Y).

        Raises
        ------
        ValueError
            If no row of X can be turned into a graph.
        """
        self.train_mean = Y.mean().values

        # Prepare data
        data_list = self._prepare_data(X, Y)
        if len(data_list) == 0:
            raise ValueError('No valid data samples')

        # Create model
        in_channels = data_list[0].x.shape[1]
        self.model = MolGNN(in_channels=in_channels, hidden_dim=self.hidden_dim, out_dim=3).to(device)

        # Training setup
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.MSELoss()

        # BatchNorm1d raises in training mode when it sees a single row. That
        # happens when the trailing mini-batch holds exactly one graph and the
        # graph has exactly one node. Drop that trailing batch only when this
        # can actually occur, so training is otherwise unchanged.
        drop_last = (
            len(data_list) > self.batch_size
            and len(data_list) % self.batch_size == 1
            and any(graph.x.shape[0] == 1 for graph in data_list)
        )
        loader = PyGDataLoader(
            data_list, batch_size=self.batch_size, shuffle=True, drop_last=drop_last
        )

        # Training loop
        self.model.train()
        for _ in range(self.epochs):
            for batch in loader:
                batch = batch.to(device)
                # PyG concatenates the per-graph vectors; restore (n_graphs, 5).
                process_feats = batch.process_feats.view(-1, self.N_PROCESS_FEATS).to(device)

                optimizer.zero_grad()
                out = self.model(batch, process_feats)
                loss = criterion(out, batch.y.view(-1, 3))
                loss.backward()
                optimizer.step()

    def predict(self, X):
        """Return a float32 tensor of shape (len(X), 3), row-aligned with X.

        Rows that cannot be turned into a graph (unknown solvent) receive the
        training-label mean, so the output always has one row per input row.
        Previously such rows were dropped, which shifted the remaining
        predictions against the wrong test rows.

        Raises
        ------
        RuntimeError
            If called before `train_model`.
        """
        if self.train_mean is None:
            raise RuntimeError('train_model must be called before predict')

        data_list, kept_rows = self._build_graphs(X)

        # Start from the mean fallback, then overwrite rows the model can score.
        preds = torch.tensor(np.tile(self.train_mean, (len(X), 1)), dtype=torch.float)
        if len(data_list) == 0:
            return preds

        if self.model is None:
            raise RuntimeError('train_model must be called before predict')

        loader = PyGDataLoader(data_list, batch_size=self.batch_size, shuffle=False)

        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in loader:
                batch = batch.to(device)
                process_feats = batch.process_feats.view(-1, self.N_PROCESS_FEATS).to(device)
                out = self.model(batch, process_feats)
                all_preds.append(out.cpu())

        # shuffle=False keeps loader order equal to kept_rows order.
        preds[torch.tensor(kept_rows, dtype=torch.long)] = torch.cat(all_preds, dim=0).float()
        return preds

print('GNNModelWrapper defined')

In [None]:
# Quick test of the model
print('Testing GNN model...')
X_single, Y_single = load_data('single_solvent')
print(f'Single solvent data: X={X_single.shape}, Y={Y_single.shape}')

# Test on a small subset
test_model = GNNModelWrapper(data='single', epochs=5, batch_size=32)
test_model.train_model(X_single.head(100), Y_single.head(100))
test_preds = test_model.predict(X_single.head(10))
print(f'Test predictions shape: {test_preds.shape}')
print(f'Test predictions:\n{test_preds[:3]}')

# Make the smoke test actually check something: one finite prediction row per
# input row, three targets each. A row-count mismatch here means predictions
# would be misaligned with the test rows during cross-validation.
assert tuple(test_preds.shape) == (10, 3), f'unexpected prediction shape {tuple(test_preds.shape)}'
assert bool(torch.isfinite(test_preds).all()), 'non-finite predictions'

In [None]:
# Cross-validation for single solvent data
print("="*60)
print("Cross-validation: Single Solvent Data (Leave-One-Out)")
print("="*60)

X_single, Y_single = load_data("single_solvent")
print(f"Single solvent data: X={X_single.shape}, Y={Y_single.shape}")

# One fold per distinct solvent; derived from the data so the progress bar
# stays correct if the dataset changes (was a hard-coded 24).
n_folds_single = X_single["SOLVENT NAME"].nunique()

all_mse_single = []
for (train_X, train_Y), (test_X, test_Y) in tqdm.tqdm(generate_leave_one_out_splits(X_single, Y_single), total=n_folds_single):
    model = GNNModelWrapper(data='single', epochs=100, batch_size=32, lr=0.001)
    model.train_model(train_X, train_Y)
    preds = model.predict(test_X).numpy()
    # Guard against row misalignment (or silent broadcasting) before scoring.
    assert preds.shape == test_Y.shape, f"prediction shape {preds.shape} != target shape {test_Y.shape}"
    mse = np.mean((preds - test_Y.values) ** 2)
    all_mse_single.append(mse)

# Unweighted mean of the per-fold MSEs.
mse_single = np.mean(all_mse_single)
print(f"\nSingle Solvent MSE: {mse_single:.6f} (+/- {np.std(all_mse_single):.6f})")

In [None]:
# Cross-validation for full data
print("="*60)
print("Cross-validation: Full Data (Leave-One-Ramp-Out)")
print("="*60)

X_full, Y_full = load_data("full")
print(f"Full data: X={X_full.shape}, Y={Y_full.shape}")

# One fold per distinct (solvent A, solvent B) pair; derived from the data so
# the progress bar stays correct if the dataset changes (was a hard-coded 13).
n_folds_full = len(X_full[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates())

all_mse_full = []
for (train_X, train_Y), (test_X, test_Y) in tqdm.tqdm(generate_leave_one_ramp_out_splits(X_full, Y_full), total=n_folds_full):
    model = GNNModelWrapper(data='full', epochs=100, batch_size=32, lr=0.001)
    model.train_model(train_X, train_Y)
    preds = model.predict(test_X).numpy()
    # Guard against row misalignment (or silent broadcasting) before scoring.
    assert preds.shape == test_Y.shape, f"prediction shape {preds.shape} != target shape {test_Y.shape}"
    mse = np.mean((preds - test_Y.values) ** 2)
    all_mse_full.append(mse)

# Unweighted mean of the per-fold MSEs.
mse_full = np.mean(all_mse_full)
print(f"\nFull Data MSE: {mse_full:.6f} (+/- {np.std(all_mse_full):.6f})")

In [None]:
# Calculate overall MSE
# Reference numbers used in the comparison below. Named once so the printed
# value and the value used in the arithmetic cannot drift apart.
BEST_ENSEMBLE_MSE = 0.008298   # best GP+MLP+LGBM ensemble (exp_030)
CV_TO_LB_SLOPE = 4.31          # slope of the CV -> LB line
CV_TO_LB_INTERCEPT = 0.0525    # intercept of the CV -> LB line
TARGET_LB = 0.0347

N_single = len(X_single)
N_full = len(X_full)
N_total = N_single + N_full

# Row-count-weighted average of the two task scores. Each task score is itself
# an unweighted mean over folds.
overall_mse = (mse_single * N_single + mse_full * N_full) / N_total

print("="*60)
print("SUMMARY")
print("="*60)
print(f"\nGNN Model:")
print(f"  Single Solvent MSE: {mse_single:.6f}")
print(f"  Full Data MSE: {mse_full:.6f}")
print(f"  Overall MSE: {overall_mse:.6f}")

print(f"\nComparison:")
print(f"  Best GP+MLP+LGBM ensemble (exp_030): {BEST_ENSEMBLE_MSE:.6f}")
print(f"  This GNN vs Best: {(overall_mse - BEST_ENSEMBLE_MSE) / BEST_ENSEMBLE_MSE * 100:.2f}%")

# Expected LB based on CV-LB relationship
expected_lb = CV_TO_LB_SLOPE * overall_mse + CV_TO_LB_INTERCEPT
print(f"\nExpected LB (based on CV-LB line): {expected_lb:.4f}")
print(f"Target LB: {TARGET_LB:.4f}")

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModelWrapper(data='single')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModelWrapper(data='full')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################