# GNN Model with AttentiveFP

**Problem**: CV-LB gap has intercept (0.0525) > target (0.0347). Current approach CANNOT reach target.

**Solution**: GNN learns from molecular STRUCTURE, not IDENTITY. Can generalize to unseen solvents.

**GNN Benchmark**: MSE 0.0039 on this exact dataset (22x better than our best LB!)

**Baseline**: exp_035 CV 0.008194, LB 0.0877

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch Geometric imports
from torch_geometric.data import Data, Batch
from torch_geometric.nn import AttentiveFP, global_mean_pool

# RDKit imports
from rdkit import Chem
from rdkit.Chem import AllChem

np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.set_default_dtype(torch.float32)  # GNN works better with float32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# Data loading functions
DATA_PATH = '/home/data'

INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
INPUT_LABELS_SINGLE_SOLVENT = ["Residence Time", "Temperature", "SOLVENT NAME"]
INPUT_LABELS_FULL_SOLVENT = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]

def load_data(name="full"):
    if name == "full":
        df = pd.read_csv(f'{DATA_PATH}/catechol_full_data_yields.csv')
        X = df[INPUT_LABELS_FULL_SOLVENT]
    else:
        df = pd.read_csv(f'{DATA_PATH}/catechol_single_solvent_yields.csv')
        X = df[INPUT_LABELS_SINGLE_SOLVENT]
    Y = df[["Product 2", "Product 3", "SM"]]
    return X, Y

def generate_leave_one_out_splits(X, Y):
    for solvent in sorted(X["SOLVENT NAME"].unique()):
        mask = X["SOLVENT NAME"] != solvent
        yield (X[mask], Y[mask]), (X[~mask], Y[~mask])

def generate_leave_one_ramp_out_splits(X, Y):
    ramps = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
    for _, row in ramps.iterrows():
        mask = ~((X["SOLVENT A NAME"] == row["SOLVENT A NAME"]) & (X["SOLVENT B NAME"] == row["SOLVENT B NAME"]))
        yield (X[mask], Y[mask]), (X[~mask], Y[~mask])

print('Data loading functions defined')

Data loading functions defined


In [3]:
# Load SMILES lookup
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
print(f'SMILES lookup: {len(SMILES_DF)} solvents')
print(SMILES_DF.head())

SMILES lookup: 26 solvents
                                           solvent smiles
SOLVENT NAME                                             
Cyclohexane                                      C1CCCCC1
Ethyl Acetate                                   O=C(OCC)C
Acetic Acid                                       CC(=O)O
2-Methyltetrahydrofuran [2-MeTHF]              O1C(C)CCC1
1,1,1,3,3,3-Hexafluoropropan-2-ol  C(C(F)(F)F)(C(F)(F)F)O


In [4]:
# SMILES to molecular graph conversion
def smiles_to_graph(smiles):
    """Convert SMILES to PyTorch Geometric Data object.
    
    For mixture SMILES (e.g., 'O.CC#N'), we process the first component.
    """
    # Handle mixture SMILES by taking the first component
    if '.' in smiles:
        smiles = smiles.split('.')[0]
    
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles}")
    
    # Atom features (6 features)
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
            atom.GetTotalNumHs(),
        ]
        atom_features.append(features)
    
    x = torch.tensor(atom_features, dtype=torch.float)
    
    # Edge features (3 features)
    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_index.extend([[i, j], [j, i]])
        
        bond_features = [
            float(bond.GetBondTypeAsDouble()),
            int(bond.GetIsAromatic()),
            int(bond.IsInRing()),
        ]
        edge_attr.extend([bond_features, bond_features])
    
    if len(edge_index) == 0:
        # Single atom molecule (e.g., Water 'O')
        edge_index = [[0, 0]]
        edge_attr = [[0.0, 0, 0]]
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Test the function
test_smiles = SMILES_DF.loc['Ethanol', 'solvent smiles']
print(f"Test SMILES: {test_smiles}")
graph = smiles_to_graph(test_smiles)
print(f"Graph: {graph.num_nodes} nodes, {graph.num_edges} edges")
print(f"Node features shape: {graph.x.shape}")
print(f"Edge features shape: {graph.edge_attr.shape}")

Test SMILES: CCO
Graph: 3 nodes, 4 edges
Node features shape: torch.Size([3, 6])
Edge features shape: torch.Size([4, 3])


In [5]:
# Pre-compute molecular graphs for all solvents
SOLVENT_GRAPHS = {}
for solvent_name in SMILES_DF.index:
    smiles = SMILES_DF.loc[solvent_name, 'solvent smiles']
    try:
        graph = smiles_to_graph(smiles)
        SOLVENT_GRAPHS[solvent_name] = graph
        print(f"{solvent_name}: {graph.num_nodes} nodes, {graph.num_edges} edges")
    except Exception as e:
        print(f"ERROR {solvent_name}: {e}")

print(f"\nTotal graphs: {len(SOLVENT_GRAPHS)}")

Cyclohexane: 6 nodes, 12 edges
Ethyl Acetate: 6 nodes, 10 edges
Acetic Acid: 4 nodes, 6 edges
2-Methyltetrahydrofuran [2-MeTHF]: 6 nodes, 12 edges
1,1,1,3,3,3-Hexafluoropropan-2-ol: 10 nodes, 18 edges
IPA [Propan-2-ol]: 4 nodes, 6 edges
Ethanol: 3 nodes, 4 edges
Methanol: 2 nodes, 2 edges
Ethylene Glycol [1,2-Ethanediol]: 4 nodes, 6 edges
Acetonitrile: 3 nodes, 4 edges
Water: 1 nodes, 1 edges
Diethyl Ether [Ether]: 5 nodes, 8 edges
MTBE [tert-Butylmethylether]: 6 nodes, 10 edges
Dimethyl Carbonate: 6 nodes, 10 edges
tert-Butanol [2-Methylpropan-2-ol]: 5 nodes, 8 edges
DMA [N,N-Dimethylacetamide]: 6 nodes, 10 edges
2,2,2-Trifluoroethanol: 6 nodes, 10 edges
Dihydrolevoglucosenone (Cyrene): 9 nodes, 20 edges
Decanol: 11 nodes, 20 edges
Butanone [MEK]: 5 nodes, 8 edges
Ethyl Lactate: 8 nodes, 14 edges
Methyl Propionate: 6 nodes, 10 edges
THF [Tetrahydrofuran]: 5 nodes, 10 edges
Water.Acetonitrile: 1 nodes, 1 edges
Acetonitrile.Acetic Acid: 3 nodes, 4 edges
Water.2,2,2-Trifluoroethanol: 1 n

In [6]:
# GNN Model using AttentiveFP
class GNNModel(nn.Module):
    def __init__(self, data='single'):
        super().__init__()
        self.data_type = data
        
        # AttentiveFP for molecular property prediction
        self.gnn = AttentiveFP(
            in_channels=6,      # atom features
            hidden_channels=64,
            out_channels=32,    # embedding dim
            edge_dim=3,         # edge features
            num_layers=2,
            num_timesteps=2,
            dropout=0.1
        )
        
        # Kinetics features: T, t, 1/T, ln(t), interaction
        kinetics_dim = 5
        
        if data == 'single':
            input_dim = 32 + kinetics_dim  # GNN embedding + kinetics
        else:
            input_dim = 64 + kinetics_dim + 1  # 2 GNN embeddings + kinetics + pct
        
        # Prediction head
        self.predictor = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 3),
            nn.Sigmoid()
        )
    
    def forward(self, graph_batch, kinetics, pct=None, graph_batch_b=None):
        # Get GNN embeddings
        emb_a = self.gnn(graph_batch.x, graph_batch.edge_index, 
                        graph_batch.edge_attr, graph_batch.batch)
        
        if self.data_type == 'single':
            x = torch.cat([emb_a, kinetics], dim=1)
        else:
            emb_b = self.gnn(graph_batch_b.x, graph_batch_b.edge_index,
                           graph_batch_b.edge_attr, graph_batch_b.batch)
            x = torch.cat([emb_a, emb_b, pct.unsqueeze(1), kinetics], dim=1)
        
        return self.predictor(x)

print('GNNModel defined')

GNNModel defined


In [7]:
# GNN Wrapper for training and prediction
class GNNWrapper:
    def __init__(self, data='single', n_models=3):
        self.data_type = data
        self.n_models = n_models
        self.models = []
        self.solvent_graphs = SOLVENT_GRAPHS
    
    def _get_kinetics(self, X):
        """Extract kinetics features: time, temp, 1/T, ln(t), interaction"""
        X_vals = X[INPUT_LABELS_NUMERIC].values.astype(np.float32)
        temp_c = X_vals[:, 1:2]
        time_m = X_vals[:, 0:1]
        temp_k = temp_c + 273.15
        inv_temp = 1000.0 / temp_k
        log_time = np.log(time_m + 1e-6)
        interaction = inv_temp * log_time
        return np.hstack([X_vals, inv_temp, log_time, interaction])
    
    def _get_graphs(self, solvent_names):
        """Get list of graphs for solvent names"""
        graphs = []
        for name in solvent_names:
            if name in self.solvent_graphs:
                graphs.append(self.solvent_graphs[name])
            else:
                # Fallback: create a simple graph
                print(f"Warning: No graph for {name}")
                graphs.append(self.solvent_graphs['Water'])
        return graphs
    
    def train_model(self, X_train, y_train, epochs=200, batch_size=32, lr=1e-3):
        kinetics = torch.tensor(self._get_kinetics(X_train), dtype=torch.float32)
        y_vals = torch.tensor(y_train.values, dtype=torch.float32)
        
        if self.data_type == 'single':
            graphs = self._get_graphs(X_train["SOLVENT NAME"].values)
        else:
            graphs_a = self._get_graphs(X_train["SOLVENT A NAME"].values)
            graphs_b = self._get_graphs(X_train["SOLVENT B NAME"].values)
            pct = torch.tensor(X_train["SolventB%"].values, dtype=torch.float32)
        
        self.models = []
        for i in range(self.n_models):
            torch.manual_seed(42 + i)
            model = GNNModel(data=self.data_type).to(device)
            optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
            criterion = nn.HuberLoss()
            
            model.train()
            n_samples = len(kinetics)
            
            for epoch in range(epochs):
                # Shuffle indices
                indices = torch.randperm(n_samples)
                
                for start in range(0, n_samples, batch_size):
                    end = min(start + batch_size, n_samples)
                    batch_idx = indices[start:end]
                    
                    # Get batch data
                    batch_kinetics = kinetics[batch_idx].to(device)
                    batch_y = y_vals[batch_idx].to(device)
                    
                    if self.data_type == 'single':
                        batch_graphs = [graphs[j] for j in batch_idx]
                        batch_graph = Batch.from_data_list(batch_graphs).to(device)
                        pred = model(batch_graph, batch_kinetics)
                    else:
                        batch_graphs_a = [graphs_a[j] for j in batch_idx]
                        batch_graphs_b = [graphs_b[j] for j in batch_idx]
                        batch_graph_a = Batch.from_data_list(batch_graphs_a).to(device)
                        batch_graph_b = Batch.from_data_list(batch_graphs_b).to(device)
                        batch_pct = pct[batch_idx].to(device)
                        pred = model(batch_graph_a, batch_kinetics, batch_pct, batch_graph_b)
                    
                    optimizer.zero_grad()
                    loss = criterion(pred, batch_y)
                    loss.backward()
                    optimizer.step()
            
            model.eval()
            self.models.append(model)
    
    def predict(self, X_test):
        kinetics = torch.tensor(self._get_kinetics(X_test), dtype=torch.float32).to(device)
        
        if self.data_type == 'single':
            graphs = self._get_graphs(X_test["SOLVENT NAME"].values)
            graph_batch = Batch.from_data_list(graphs).to(device)
        else:
            graphs_a = self._get_graphs(X_test["SOLVENT A NAME"].values)
            graphs_b = self._get_graphs(X_test["SOLVENT B NAME"].values)
            graph_batch_a = Batch.from_data_list(graphs_a).to(device)
            graph_batch_b = Batch.from_data_list(graphs_b).to(device)
            pct = torch.tensor(X_test["SolventB%"].values, dtype=torch.float32).to(device)
        
        preds = []
        with torch.no_grad():
            for model in self.models:
                if self.data_type == 'single':
                    pred = model(graph_batch, kinetics)
                else:
                    pred = model(graph_batch_a, kinetics, pct, graph_batch_b)
                preds.append(pred.cpu())
        
        return torch.clamp(torch.stack(preds).mean(dim=0), 0, 1).double()

print('GNNWrapper defined')

GNNWrapper defined


In [8]:
# Quick test on single fold
X_single, Y_single = load_data("single_solvent")
test_solvent = sorted(X_single["SOLVENT NAME"].unique())[0]
mask = X_single["SOLVENT NAME"] != test_solvent

print(f"Test solvent: {test_solvent}")
print(f"Training samples: {mask.sum()}, Test samples: {(~mask).sum()}")

model = GNNWrapper(data='single', n_models=1)
model.train_model(X_single[mask], Y_single[mask], epochs=50)
preds = model.predict(X_single[~mask])

actuals = Y_single[~mask].values
mse = np.mean((actuals - preds.numpy()) ** 2)
print(f'Test fold MSE: {mse:.6f}')
print(f'Predictions shape: {preds.shape}')

Test solvent: 1,1,1,3,3,3-Hexafluoropropan-2-ol
Training samples: 619, Test samples: 37


Test fold MSE: 0.068767
Predictions shape: torch.Size([37, 3])


In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNWrapper(data='single', n_models=3)  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y, epochs=100)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNWrapper(data='full', n_models=3)  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y, epochs=100)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
# Calculate CV score (for verification only - NOT part of submission)
X_single, Y_single = load_data("single_solvent")
X_full, Y_full = load_data("full")

# Get actuals in same order as predictions
actuals_single = []
for solvent in sorted(X_single["SOLVENT NAME"].unique()):
    mask = X_single["SOLVENT NAME"] == solvent
    actuals_single.append(Y_single[mask].values)
actuals_single = np.vstack(actuals_single)

actuals_full = []
ramps = X_full[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
for _, row in ramps.iterrows():
    mask = (X_full["SOLVENT A NAME"] == row["SOLVENT A NAME"]) & (X_full["SOLVENT B NAME"] == row["SOLVENT B NAME"])
    actuals_full.append(Y_full[mask].values)
actuals_full = np.vstack(actuals_full)

# Get predictions
preds_single = submission_single_solvent[['target_1', 'target_2', 'target_3']].values
preds_full = submission_full_data[['target_1', 'target_2', 'target_3']].values

# Calculate MSE
mse_single = np.mean((actuals_single - preds_single) ** 2)
mse_full = np.mean((actuals_full - preds_full) ** 2)
n_single = len(actuals_single)
n_full = len(actuals_full)
overall_mse = (mse_single * n_single + mse_full * n_full) / (n_single + n_full)

print(f'\n=== CV SCORE VERIFICATION ===')
print(f'Single Solvent MSE: {mse_single:.6f} (n={n_single})')
print(f'Full Data MSE: {mse_full:.6f} (n={n_full})')
print(f'Overall MSE: {overall_mse:.6f}')
print(f'\nexp_035 baseline (GP+MLP+LGBM): CV 0.008194')
print(f'This (GNN): CV {overall_mse:.6f}')

if overall_mse < 0.008194:
    improvement = (0.008194 - overall_mse) / 0.008194 * 100
    print(f'\n✓ IMPROVEMENT: {improvement:.2f}% better than exp_035!')
else:
    degradation = (overall_mse - 0.008194) / 0.008194 * 100
    print(f'\n✗ WORSE: {degradation:.2f}% worse than exp_035')