# Experiment 085: Graph Neural Network (GNN) for Solvent Prediction

**Rationale**: The GNN benchmark achieved a CV score of 0.0039 — much better than our best of 0.0081. GNNs may change the CV–LB (cross-validation vs. leaderboard) relationship because:
- They learn molecular STRUCTURE directly via message-passing
- They have inductive bias that helps with extrapolation to unseen molecules
- They don't rely on pre-computed features that may not generalize

**Implementation**:
- Use PyTorch Geometric
- Use RDKit to convert SMILES to molecular graphs
- Simple GCN architecture with 3 layers
- Train with the same leave-one-out CV scheme (leave-one-solvent-out for single-solvent data, leave-one-ramp-out for mixtures)

In [1]:
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from rdkit import Chem
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool
import warnings
warnings.filterwarnings('ignore')

SEED = 42
# Seed NumPy's global RNG first, then PyTorch (torch.manual_seed seeds every device).
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Run on the GPU whenever PyTorch can see one.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Device: {device}')
print('Imports done')

Device: cuda
Imports done


In [2]:
# Load data
def load_data(data_type, data_dir='/home/data'):
    """Load one of the Catechol CSV files as an (inputs, targets) pair.

    Args:
        data_type: ``"single_solvent"`` for the single-solvent file or
            ``"full"`` for the mixed-solvent (ramp) file.
        data_dir: directory containing the CSV files.

    Returns:
        ``(X, Y)`` DataFrames. ``X`` holds Residence Time, Temperature and the
        solvent column(s) (``SOLVENT A NAME``, ``SOLVENT B NAME`` and
        ``SolventB%`` for ``"full"``); ``Y`` holds the target columns
        ``SM``, ``Product 2`` and ``Product 3``.

    Raises:
        ValueError: if ``data_type`` is not one of the supported values.
    """
    if data_type == "single_solvent":
        df = pd.read_csv(f"{data_dir}/catechol_single_solvent_yields.csv")
        x_cols = ['Residence Time', 'Temperature', 'SOLVENT NAME']
    elif data_type == "full":
        df = pd.read_csv(f"{data_dir}/catechol_full_data_yields.csv")
        x_cols = ['Residence Time', 'Temperature', 'SOLVENT A NAME', 'SOLVENT B NAME', 'SolventB%']
    else:
        # Fail loudly instead of hitting an UnboundLocalError on the return line.
        raise ValueError(f"Unknown data_type {data_type!r}; expected 'single_solvent' or 'full'")
    return df[x_cols], df[['SM', 'Product 2', 'Product 3']]

# Build the solvent-name -> SMILES lookup table
smiles_df = pd.read_csv('/home/data/smiles_lookup.csv')
smiles_dict = {
    solvent: smi
    for solvent, smi in zip(smiles_df['SOLVENT NAME'], smiles_df['solvent smiles'])
}
preview = list(smiles_dict.items())[:3]
print(f"Loaded {len(smiles_dict)} SMILES")
print("Sample:", preview)

Loaded 26 SMILES
Sample: [('Cyclohexane', 'C1CCCCC1'), ('Ethyl Acetate', 'O=C(OCC)C'), ('Acetic Acid', 'CC(=O)O')]


In [3]:
# Official CV split functions (DO NOT MODIFY)
from typing import Any, Generator

def generate_leave_one_out_splits(
    X: pd.DataFrame, Y: pd.DataFrame
) -> Generator[
    tuple[tuple[pd.DataFrame, pd.DataFrame], tuple[pd.DataFrame, pd.DataFrame]],
    Any,
    None,
]:
    """Yield one leave-one-solvent-out fold per distinct ``SOLVENT NAME``.

    Folds come out in the order each solvent first appears in ``X``
    (``Series.unique`` preserves order of appearance). Each fold is
    ``((train_X, train_Y), (test_X, test_Y))``. The test part holds every row
    of that solvent, and the train part holds all remaining rows.

    ``Y`` must share ``X``'s index, because the boolean masks built from ``X``
    are applied to ``Y`` as well.
    """
    for solvent in X["SOLVENT NAME"].unique():
        train_mask = X["SOLVENT NAME"] != solvent
        test_mask = X["SOLVENT NAME"] == solvent
        yield (
            (X[train_mask], Y[train_mask]),
            (X[test_mask], Y[test_mask]),
        )

def generate_leave_one_ramp_out_splits(
    X: pd.DataFrame, Y: pd.DataFrame
) -> Generator[
    tuple[tuple[pd.DataFrame, pd.DataFrame], tuple[pd.DataFrame, pd.DataFrame]],
    Any,
    None,
]:
    """Yield one fold per solvent ramp, i.e. per (solvent A, solvent B) pair.

    A ramp key is ``"<SOLVENT A NAME>_<SOLVENT B NAME>"`` built from the
    string form of both columns. The pair is therefore ordered: A_B and B_A
    are distinct ramps. Each fold is ``((train_X, train_Y), (test_X, test_Y))``,
    with every row of the held-out ramp in the test part.

    ``Y`` must share ``X``'s index, because the boolean masks built from ``X``
    are applied to ``Y`` as well.
    """
    # astype(str) also turns missing names into the literal 'nan' rather than dropping them.
    ramps = X["SOLVENT A NAME"].astype(str) + "_" + X["SOLVENT B NAME"].astype(str)
    for ramp in ramps.unique():
        train_mask = ramps != ramp
        test_mask = ramps == ramp
        yield (
            (X[train_mask], Y[train_mask]),
            (X[test_mask], Y[test_mask]),
        )

print('CV split functions defined')

CV split functions defined


In [4]:
# Convert SMILES to molecular graph
def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Nodes are the atoms RDKit parses; hydrogens stay implicit. Each node
    carries 7 raw, unscaled features:

    - atomic number
    - degree
    - formal charge
    - hybridization (as int)
    - aromatic flag
    - total H count
    - in-ring flag

    Every bond is stored as two directed edges so message passing sees an
    undirected graph. '.'-separated SMILES (mixtures) give a disconnected graph.

    Args:
        smiles: SMILES string.

    Returns:
        ``Data(x, edge_index)`` with ``x`` of shape (num_atoms, 7) and
        ``edge_index`` of shape (2, num_edges). Returns ``None`` if RDKit
        cannot parse ``smiles`` or the molecule has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None on a parse error but an empty Mol for "".
    # An atom-less graph would get a self-loop on a non-existent node 0
    # below, so both cases are rejected.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs(),
            int(atom.IsInRing()),
        ]
        for atom in mol.GetAtoms()
    ]

    edge_index = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index += [[i, j], [j, i]]  # both directions -> undirected

    if not edge_index:
        # Bond-less molecule (e.g. water 'O'): keep edge_index non-empty
        # with a self-loop on atom 0.
        edge_index = [[0, 0]]

    x = torch.tensor(atom_features, dtype=torch.float)
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)

# Quick sanity check of the converter on ethanol
test_graph = smiles_to_graph('CCO')
node_count, edge_count = test_graph.num_nodes, test_graph.num_edges
print(f"Ethanol graph: {node_count} nodes, {edge_count} edges")
feature_shape = test_graph.x.shape
print(f"Node features shape: {feature_shape}")

Ethanol graph: 3 nodes, 4 edges
Node features shape: torch.Size([3, 7])


In [5]:
# Convert every SMILES in the lookup once, up front; unparseable entries are reported and skipped
solvent_graphs = {}
for solvent_name, solvent_smiles in smiles_dict.items():
    g = smiles_to_graph(solvent_smiles)
    if g is None:
        print(f"WARNING: Could not parse {solvent_name}: {solvent_smiles}")
        continue
    solvent_graphs[solvent_name] = g
    print(f"{solvent_name}: {g.num_nodes} atoms, {g.num_edges} edges")

print(f"\nTotal: {len(solvent_graphs)} solvent graphs")

Cyclohexane: 6 atoms, 12 edges
Ethyl Acetate: 6 atoms, 10 edges
Acetic Acid: 4 atoms, 6 edges
2-Methyltetrahydrofuran [2-MeTHF]: 6 atoms, 12 edges
1,1,1,3,3,3-Hexafluoropropan-2-ol: 10 atoms, 18 edges
IPA [Propan-2-ol]: 4 atoms, 6 edges
Ethanol: 3 atoms, 4 edges
Methanol: 2 atoms, 2 edges
Ethylene Glycol [1,2-Ethanediol]: 4 atoms, 6 edges
Acetonitrile: 3 atoms, 4 edges
Water: 1 atoms, 1 edges
Diethyl Ether [Ether]: 5 atoms, 8 edges
MTBE [tert-Butylmethylether]: 6 atoms, 10 edges
Dimethyl Carbonate: 6 atoms, 10 edges
tert-Butanol [2-Methylpropan-2-ol]: 5 atoms, 8 edges
DMA [N,N-Dimethylacetamide]: 6 atoms, 10 edges
2,2,2-Trifluoroethanol: 6 atoms, 10 edges
Dihydrolevoglucosenone (Cyrene): 9 atoms, 20 edges
Decanol: 11 atoms, 20 edges
Butanone [MEK]: 5 atoms, 8 edges
Ethyl Lactate: 8 atoms, 14 edges
Methyl Propionate: 6 atoms, 10 edges
THF [Tetrahydrofuran]: 5 atoms, 10 edges
Water.Acetonitrile: 4 atoms, 4 edges
Acetonitrile.Acetic Acid: 7 atoms, 10 edges
Water.2,2,2-Trifluoroethanol: 7 

In [6]:
# GNN Model
class SolventGNN(nn.Module):
    """Three GCN layers, mean pooling, and an MLP head with a sigmoid output.

    Node embeddings are averaged per graph, concatenated with two per-graph
    scalars (temperature and residence time), and mapped to ``out_channels``
    values squashed into (0, 1).
    """

    def __init__(self, in_channels=7, hidden_channels=64, out_channels=3):
        super().__init__()
        # Attribute names and creation order are kept stable so parameter
        # initialisation consumes the RNG the same way and state_dict keys match.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        n_conditions = 2  # Temperature and Residence Time
        half = hidden_channels // 2
        self.lin = nn.Sequential(
            nn.Linear(hidden_channels + n_conditions, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_channels, half),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(half, out_channels),
            nn.Sigmoid(),
        )

    def forward(self, data, T, RT):
        """Return predictions of shape (num_graphs, out_channels).

        Args:
            data: graph batch exposing ``x``, ``edge_index`` and ``batch``.
            T: one temperature value per graph (1-D tensor).
            RT: one residence-time value per graph (1-D tensor).
        """
        h = data.x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h, data.edge_index))

        pooled = global_mean_pool(h, data.batch)
        conditions = torch.cat([T.unsqueeze(1), RT.unsqueeze(1)], dim=1)
        return self.lin(torch.cat([pooled, conditions], dim=1))

print('SolventGNN defined')

SolventGNN defined


In [7]:
# GNN Model wrapper (single-solvent and mixed-solvent data)
class GNNModel:
    """Fit/predict wrapper that trains a fresh ``SolventGNN`` on each call.

    Each row is mapped to one precomputed graph from the module-level
    ``solvent_graphs`` dict. For mixed-solvent data (``data='full'``) only the
    dominant solvent's graph is used; the ``SolventB%`` value itself is not
    given to the network. Temperature and residence time are standardised
    with training-set statistics and passed alongside the graph.

    Args:
        data: ``'single'`` (rows use ``SOLVENT NAME``) or ``'full'`` (rows use
            ``SOLVENT A NAME`` / ``SOLVENT B NAME`` / ``SolventB%``).
        hidden_channels: width of the GCN layers.
        num_epochs: number of passes over the training rows.
        lr: Adam learning rate.
        log_every: if > 0, print the mean training loss every ``log_every`` epochs.
    """

    def __init__(self, data='single', hidden_channels=64, num_epochs=200, lr=1e-3, log_every=0):
        self.data_type = data
        self.mixed = (data == 'full')
        self.hidden_channels = hidden_channels
        self.num_epochs = num_epochs
        self.lr = lr
        self.log_every = log_every
        self.model = None

    def _solvent_name(self, row):
        """Return the name of the solvent whose graph represents ``row``."""
        if not self.mixed:
            return row['SOLVENT NAME']
        # NOTE(review): this assumes SolventB% is on a 0-100 scale. If the
        # column is already a fraction in [0, 1], the threshold below always
        # selects solvent A. Confirm against the data.
        sb_frac = row['SolventB%'] / 100.0
        return row['SOLVENT A NAME'] if sb_frac < 0.5 else row['SOLVENT B NAME']

    def _collect(self, X):
        """Gather graphs and raw conditions for rows of ``X`` with a known solvent.

        Returns:
            ``(graphs, temps, rts, kept)``. ``kept`` lists the positional
            indices of the rows that were used. Rows whose solvent has no
            graph are skipped, with a warning.
        """
        graphs, temps, rts, kept = [], [], [], []
        for i in range(len(X)):
            row = X.iloc[i]
            name = self._solvent_name(row)
            if name not in solvent_graphs:
                continue
            graphs.append(solvent_graphs[name].clone())
            temps.append(row['Temperature'])
            rts.append(row['Residence Time'])
            kept.append(i)
        skipped = len(X) - len(kept)
        if skipped:
            print(f"WARNING: {skipped} row(s) have no solvent graph and were skipped")
        return graphs, temps, rts, kept

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, std + 1e-6)`` of a 1-D float tensor.

        ``Tensor.std`` (unbiased) is NaN for a single element. In that case
        the spread is treated as 0, so normalisation stays finite.
        """
        std = values.std() if values.numel() > 1 else torch.zeros(())
        return values.mean(), std + 1e-6

    def train_model(self, train_X, train_Y):
        """Fit a new ``SolventGNN`` on ``train_X`` / ``train_Y``.

        If no row has a solvent graph, a warning is printed and ``self.model``
        is left unchanged.
        """
        graphs, temps, rts, kept = self._collect(train_X)
        if not graphs:
            print("WARNING: No valid graphs found!")
            return

        temps = torch.tensor(temps, dtype=torch.float)
        rts = torch.tensor(rts, dtype=torch.float)
        targets = torch.tensor(train_Y.iloc[kept].to_numpy(), dtype=torch.float)

        # Standardise conditions; the statistics are reused by predict().
        self.temp_mean, self.temp_std = self._mean_std(temps)
        self.rt_mean, self.rt_std = self._mean_std(rts)
        temps = (temps - self.temp_mean) / self.temp_std
        rts = (rts - self.rt_mean) / self.rt_std

        self.model = SolventGNN(in_channels=7, hidden_channels=self.hidden_channels).to(device)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.MSELoss()

        self.model.train()
        n = len(graphs)
        batch_size = min(32, n)

        for epoch in range(self.num_epochs):
            perm = torch.randperm(n)
            total_loss = 0.0
            for start in range(0, n, batch_size):
                batch_idx = perm[start:start + batch_size]
                batch = Batch.from_data_list([graphs[i] for i in batch_idx]).to(device)

                optimizer.zero_grad()
                output = self.model(batch, temps[batch_idx].to(device), rts[batch_idx].to(device))
                loss = criterion(output, targets[batch_idx].to(device))
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * len(batch_idx)

            if self.log_every and (epoch + 1) % self.log_every == 0:
                print(f"Epoch {epoch+1}/{self.num_epochs}, Loss: {total_loss/n:.6f}")

    def predict(self, test_X):
        """Predict the three targets for every row of ``test_X``.

        Returns:
            CPU float tensor of shape (len(test_X), 3). Rows whose solvent has
            no graph are left as zeros.

        Raises:
            RuntimeError: if called before a successful ``train_model``.
        """
        if self.model is None:
            raise RuntimeError("Model not trained")
        self.model.eval()

        predictions = torch.zeros(len(test_X), 3)
        graphs, temps, rts, kept = self._collect(test_X)
        if not graphs:
            return predictions

        temps = (torch.tensor(temps, dtype=torch.float) - self.temp_mean) / self.temp_std
        rts = (torch.tensor(rts, dtype=torch.float) - self.rt_mean) / self.rt_std

        with torch.no_grad():
            batch = Batch.from_data_list(graphs).to(device)
            output = self.model(batch, temps.to(device), rts.to(device)).cpu()

        predictions[kept] = output
        return predictions

print('GNNModel defined')

GNNModel defined


In [8]:
# Run CV for single solvent data (leave-one-solvent-out)
import tqdm

X, Y = load_data("single_solvent")
print(f"Single solvent data: {len(X)} samples, {len(X['SOLVENT NAME'].unique())} solvents")

# Build the folds up front so the progress bar's total always matches the
# data, rather than a hard-coded fold count that can silently go stale.
splits = list(generate_leave_one_out_splits(X, Y))
all_predictions = []
fold_mses = []

for fold_idx, ((train_X, train_Y), (test_X, test_Y)) in tqdm.tqdm(enumerate(splits), total=len(splits)):
    # A fresh model per fold, trained only on the other solvents.
    model = GNNModel(data='single', hidden_channels=64, num_epochs=150, lr=1e-3)
    model.train_model(train_X, train_Y)
    predictions_np = model.predict(test_X).detach().cpu().numpy()

    # Per-fold MSE over all held-out rows and all three targets.
    fold_mses.append(np.mean((predictions_np - test_Y.values) ** 2))

    all_predictions.extend(
        {
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2],
        }
        for row_idx, row in enumerate(predictions_np)
    )

submission_single_solvent = pd.DataFrame(all_predictions)
# Folds are averaged unweighted (each solvent counts once regardless of its row count).
print(f"\nSingle solvent CV MSE: {np.mean(fold_mses):.6f} ± {np.std(fold_mses):.6f}")

Single solvent data: 656 samples, 24 solvents


  0%|          | 0/24 [00:00<?, ?it/s]

  4%|▍         | 1/24 [00:10<03:52, 10.12s/it]

  8%|▊         | 2/24 [00:19<03:38,  9.95s/it]

 12%|█▎        | 3/24 [00:29<03:27,  9.87s/it]

 17%|█▋        | 4/24 [00:39<03:14,  9.74s/it]

 21%|██        | 5/24 [00:49<03:05,  9.77s/it]

 25%|██▌       | 6/24 [00:59<02:58,  9.94s/it]

 29%|██▉       | 7/24 [01:09<02:47,  9.88s/it]

 33%|███▎      | 8/24 [01:18<02:35,  9.70s/it]

 38%|███▊      | 9/24 [01:28<02:25,  9.72s/it]

 42%|████▏     | 10/24 [01:37<02:16,  9.74s/it]

 46%|████▌     | 11/24 [01:47<02:07,  9.78s/it]

 50%|█████     | 12/24 [01:57<01:57,  9.81s/it]

 54%|█████▍    | 13/24 [02:07<01:47,  9.79s/it]

 58%|█████▊    | 14/24 [02:17<01:37,  9.80s/it]

 62%|██████▎   | 15/24 [02:27<01:28,  9.78s/it]

 67%|██████▋   | 16/24 [02:36<01:18,  9.78s/it]

 71%|███████   | 17/24 [02:46<01:08,  9.80s/it]

 75%|███████▌  | 18/24 [02:56<00:58,  9.80s/it]

 79%|███████▉  | 19/24 [03:06<00:48,  9.79s/it]

 83%|████████▎ | 20/24 [03:15<00:39,  9.78s/it]

 88%|████████▊ | 21/24 [03:25<00:29,  9.78s/it]

 92%|█████████▏| 22/24 [03:35<00:19,  9.78s/it]

 96%|█████████▌| 23/24 [03:45<00:09,  9.78s/it]

100%|██████████| 24/24 [03:55<00:00,  9.77s/it]

100%|██████████| 24/24 [03:55<00:00,  9.79s/it]


Single solvent CV MSE: 0.026536 ± 0.025804





In [9]:
# Run CV for full (mixture) data (leave-one-ramp-out)
X, Y = load_data("full")
print(f"Full data: {len(X)} samples")

# Build the folds up front so the progress bar's total always matches the
# number of ramps, rather than a hard-coded fold count.
splits = list(generate_leave_one_ramp_out_splits(X, Y))
all_predictions = []
fold_mses = []

for fold_idx, ((train_X, train_Y), (test_X, test_Y)) in tqdm.tqdm(enumerate(splits), total=len(splits)):
    # A fresh model per fold, trained only on the other ramps.
    model = GNNModel(data='full', hidden_channels=64, num_epochs=150, lr=1e-3)
    model.train_model(train_X, train_Y)
    predictions_np = model.predict(test_X).detach().cpu().numpy()

    # Per-fold MSE over all held-out rows and all three targets.
    fold_mses.append(np.mean((predictions_np - test_Y.values) ** 2))

    all_predictions.extend(
        {
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2],
        }
        for row_idx, row in enumerate(predictions_np)
    )

submission_full_data = pd.DataFrame(all_predictions)
# Folds are averaged unweighted (each ramp counts once regardless of its row count).
print(f"\nFull data CV MSE: {np.mean(fold_mses):.6f} ± {np.std(fold_mses):.6f}")

Full data: 1227 samples


  0%|          | 0/13 [00:00<?, ?it/s]

  8%|▊         | 1/13 [00:17<03:28, 17.34s/it]

 15%|█▌        | 2/13 [00:34<03:11, 17.43s/it]

 23%|██▎       | 3/13 [00:52<02:55, 17.59s/it]

 31%|███       | 4/13 [01:09<02:36, 17.39s/it]

 38%|███▊      | 5/13 [01:26<02:18, 17.29s/it]

 46%|████▌     | 6/13 [01:43<02:00, 17.22s/it]

 54%|█████▍    | 7/13 [02:00<01:43, 17.18s/it]

 62%|██████▏   | 8/13 [02:18<01:25, 17.16s/it]

 69%|██████▉   | 9/13 [02:35<01:08, 17.18s/it]

 77%|███████▋  | 10/13 [02:53<00:52, 17.58s/it]

 85%|████████▍ | 11/13 [03:12<00:35, 17.89s/it]

 92%|█████████▏| 12/13 [03:31<00:18, 18.13s/it]

100%|██████████| 13/13 [03:49<00:00, 18.25s/it]

100%|██████████| 13/13 [03:49<00:00, 17.66s/it]


Full data CV MSE: 0.016705 ± 0.013263





In [10]:
# Stack both tasks into one submission frame and write it to disk.
# reset_index() keeps the per-task row index as an extra "index" column.
frames = [submission_single_solvent, submission_full_data]
submission = pd.concat(frames).reset_index()
submission.index.name = "id"
print(f"Submission shape: {submission.shape}")

out_path = "/home/submission/submission.csv"
submission.to_csv(out_path, index=True)
print(f"\nSubmission saved to {out_path}")

# Read the file back as a sanity check
submission_check = pd.read_csv(out_path)
print(f"\nSubmission rows: {len(submission_check)}")

# Report the predicted range of each target column
target_cols = ['target_1', 'target_2', 'target_3']
for col in target_cols:
    lo, hi = submission_check[col].min(), submission_check[col].max()
    print(f"{col}: min={lo:.4f}, max={hi:.4f}")

Submission shape: (1883, 7)

Submission saved to /home/submission/submission.csv

Submission rows: 1883
target_1: min=0.0010, max=0.9855
target_2: min=0.0000, max=0.4299
target_3: min=0.0000, max=0.4490


In [11]:
# Print a summary of the techniques used in this experiment
banner = "=" * 50
print(banner, "EXPERIMENT 085 COMPLETE", banner, sep="\n")

techniques = [
    "Graph Neural Network (GCN) for molecular representation",
    "RDKit for SMILES to graph conversion",
    "Message-passing to learn molecular structure",
    "Global mean pooling for graph-level representation",
    "Concatenate with Temperature and Residence Time",
]
print("\nKey techniques:")
for num, text in enumerate(techniques, start=1):
    print(f"{num}. {text}")
print("\nThis approach learns molecular STRUCTURE directly, which may help with extrapolation to unseen solvents.")

EXPERIMENT 085 COMPLETE

Key techniques:
1. Graph Neural Network (GCN) for molecular representation
2. RDKit for SMILES to graph conversion
3. Message-passing to learn molecular structure
4. Global mean pooling for graph-level representation
5. Concatenate with Temperature and Residence Time

This approach learns molecular STRUCTURE directly, which may help with extrapolation to unseen solvents.


In [None]:
# Calculate overall CV score
# Per-task CV means as printed by the two CV cells of this notebook.
single_cv = 0.026536
full_cv = 0.016705
best_prev_cv = 0.008092  # exp_049, leave-one-out CV

# Weight by each task's row count, taken from the submission frames
# instead of hard-coding 656 / 1227.
# NOTE(review): single_cv / full_cv are unweighted means over folds, so
# sample-weighting them is only an approximation of a pooled MSE.
n_single = len(submission_single_solvent)
n_full = len(submission_full_data)
total_samples = n_single + n_full
overall_cv = (n_single * single_cv + n_full * full_cv) / total_samples

print(f"Single solvent CV: {single_cv:.6f}")
print(f"Full data CV: {full_cv:.6f}")
print(f"Overall CV (sample-weighted): {overall_cv:.6f}")

print("\n" + "="*50)
print("COMPARISON WITH PREVIOUS RESULTS")
print("="*50)
print(f"This experiment (GNN): {overall_cv:.6f}")
print(f"Best previous CV (Leave-One-Out): {best_prev_cv:.6f} (exp_049)")
print("Best verified LB: 0.0877 (exp_030, exp_067)")

# Derive the verdict from the numbers so the text cannot drift from them.
if overall_cv > best_prev_cv:
    verdict = "MUCH WORSE" if overall_cv >= 2 * best_prev_cv else "WORSE"
else:
    verdict = "BETTER"

print("\n" + "="*50)
print("ANALYSIS")
print("="*50)
print(f"The GNN achieved CV={overall_cv:.4f}, which is {verdict} than our best ({best_prev_cv:.4f})")
print("Possible reasons:")
print("1. Simple GCN architecture may not be powerful enough")
print("2. Only 150 epochs may not be enough training")
print("3. For mixed solvents, using dominant solvent only loses information")
print("4. The GNN benchmark may have used more sophisticated architecture")
print(f"5. Small dataset ({n_single} samples) may not be enough to train GNN effectively")