# Experiment 029: Improved GNN with GAT and Edge Features

**Goal**: Implement a proper GNN approach that can break the plateau in the CV–LB (cross-validation vs. leaderboard score) relationship.

**Key improvements over exp_020 (basic GNN, CV 0.099)**:
1. Use Graph Attention Networks (GAT) instead of GCN
2. Add edge features (bond type including aromatic bonds, conjugation, ring membership)
3. Multi-task learning across all 3 targets
4. Proper learning rate schedule (cosine annealing)
5. Combine molecular graph features with process conditions

**Hypothesis**: GNN can learn generalizable molecular representations that transfer better to unseen solvents.

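**TEMPLATE COMPLIANCE**: The last 3 cells are exactly as in the template, apart from the model-construction line in the two cross-validation cells, which the template allows to be changed.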

In [5]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, Batch
from sklearn.preprocessing import StandardScaler
from rdkit import Chem
from rdkit.Chem import AllChem
from abc import ABC
import tqdm
import warnings
warnings.filterwarnings('ignore')

# Root directory that the dataset and lookup CSVs in this notebook are read from.
DATA_PATH = '/home/data'
# Make float64 the default floating-point dtype; the tensors built later in the
# notebook also request torch.double explicitly.
torch.set_default_dtype(torch.double)
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}")

Device: cuda
PyTorch: 2.2.0+cu118


In [6]:
# --- UTILITY FUNCTIONS ---
# Names of the three yield columns that load_data selects as prediction targets.
TARGET_LABELS = ["Product 2", "Product 3", "SM"]

def load_data(name="full"):
    """Read one yield table and split it into inputs and targets.

    Args:
        name: ``"full"`` reads the table with solvent A/B columns and a
            ``SolventB%`` column; ``"single_solvent"`` reads the table with
            a single ``SOLVENT NAME`` column.

    Returns:
        Tuple ``(X, Y)`` where ``X`` holds the condition and solvent columns
        of the chosen table and ``Y`` holds the ``TARGET_LABELS`` columns.

    Raises:
        AssertionError: if ``name`` is not one of the two supported values.
    """
    assert name in ["full", "single_solvent"]

    if name == "full":
        csv_name = 'catechol_full_data_yields.csv'
        input_columns = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]
    else:
        csv_name = 'catechol_single_solvent_yields.csv'
        input_columns = ["Residence Time", "Temperature", "SOLVENT NAME"]

    table = pd.read_csv(f'{DATA_PATH}/{csv_name}')
    return table[input_columns], table[TARGET_LABELS]

def load_features(name="spange_descriptors"):
    """Load ``<name>_lookup.csv`` from ``DATA_PATH``, using its first column as the index."""
    lookup_path = f'{DATA_PATH}/{name}_lookup.csv'
    return pd.read_csv(lookup_path, index_col=0)

def generate_leave_one_out_splits(X, Y):
    """Yield one train/test split per distinct solvent, in sorted solvent order.

    Each item is ``((X_train, Y_train), (X_test, Y_test))`` where the test
    part contains exactly the rows whose ``"SOLVENT NAME"`` equals the
    held-out solvent and the train part contains all remaining rows.
    """
    solvent_column = X["SOLVENT NAME"]
    for held_out in sorted(solvent_column.unique()):
        is_test = solvent_column == held_out
        is_train = ~is_test
        yield (X[is_train], Y[is_train]), (X[is_test], Y[is_test])

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield one train/test split per distinct (solvent A, solvent B) pair.

    Pairs are visited in order of first appearance in ``X``. Each item is
    ``((X_train, Y_train), (X_test, Y_test))`` where the test part holds the
    rows matching the held-out pair on both columns and the train part holds
    every other row.
    """
    unique_pairs = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
    for solvent_a, solvent_b in unique_pairs.itertuples(index=False, name=None):
        in_ramp = (X["SOLVENT A NAME"] == solvent_a) & (X["SOLVENT B NAME"] == solvent_b)
        outside_ramp = ~in_ramp
        yield (X[outside_ramp], Y[outside_ramp]), (X[in_ramp], Y[in_ramp])

# Load feature sets
# Both lookup tables are read once here via load_features; their shapes are
# printed as a quick sanity check. SPANGE_DF is consumed by GNNModel below.
# NOTE(review): ACS_PCA_DF is not referenced anywhere else in the visible part
# of this notebook - confirm whether it is still needed.
SPANGE_DF = load_features('spange_descriptors')
ACS_PCA_DF = load_features('acs_pca_descriptors')
print(f"Spange: {SPANGE_DF.shape}, ACS_PCA: {ACS_PCA_DF.shape}")

Spange: (26, 13), ACS_PCA: (24, 5)


In [7]:
# --- BASE CLASSES ---
class SmilesFeaturizer(ABC):
    """Interface for featurizers.

    Both the constructor and ``featurize`` raise ``NotImplementedError``, so
    a subclass has to override each of them before it can be used.
    """

    def __init__(self):
        raise NotImplementedError

    def featurize(self, X):
        raise NotImplementedError

class BaseModel(ABC):
    """Interface shared by the models in this notebook.

    Subclasses override ``train_model`` and ``predict``; the versions here
    only raise ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def train_model(self, X_train, y_train):
        """Fit the model on the given training inputs and targets."""
        raise NotImplementedError

    def predict(self, X=None):
        """Return predictions for ``X``.

        ``X`` is accepted (and optional) so the signature agrees with the
        way implementations are called, i.e. ``model.predict(test_X)``;
        calling without an argument still works as before.
        """
        raise NotImplementedError

In [8]:
# --- MOLECULAR GRAPH UTILITIES ---

# Load SMILES lookup
# The first CSV column becomes the index; the graph pre-computation further
# down reads the 'solvent smiles' column for each index entry.
smiles_lookup = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
print(f"SMILES lookup: {smiles_lookup.shape}")
print(smiles_lookup.head())

def get_atom_features(atom):
    """Encode one RDKit atom as a flat list of 22 numbers.

    Layout, in order:
      * 9 element slots: one-hot over C, N, O, F, S, Cl, Br, I plus a final
        "other" slot set for any element not in that list;
      * 6 degree slots: one-hot over degree 0-5 (all zero for higher degrees);
      * 3 hybridization slots: SP, SP2, SP3 (all zero for anything else);
      * 4 scalars: aromatic flag, formal charge divided by 2, number of
        radical electrons, ring-membership flag.
    """
    known_symbols = ['C', 'N', 'O', 'F', 'S', 'Cl', 'Br', 'I']
    symbol = atom.GetSymbol()
    element_onehot = [1 if symbol == candidate else 0 for candidate in known_symbols]
    element_onehot.append(0 if symbol in known_symbols else 1)  # "other" slot

    atom_degree = atom.GetDegree()
    degree_onehot = [1 if atom_degree == d else 0 for d in range(6)]

    hyb_kinds = Chem.rdchem.HybridizationType
    atom_hyb = atom.GetHybridization()
    hyb_onehot = [
        1 if atom_hyb == kind else 0
        for kind in (hyb_kinds.SP, hyb_kinds.SP2, hyb_kinds.SP3)
    ]

    scalars = [
        atom.GetIsAromatic() * 1.0,
        atom.GetFormalCharge() / 2.0,  # scaled down by a fixed factor of 2
        atom.GetNumRadicalElectrons(),
        atom.IsInRing() * 1.0,
    ]

    return element_onehot + degree_onehot + hyb_onehot + scalars

def get_bond_features(bond):
    """Encode one RDKit bond as a flat list of 6 numbers.

    The first four entries are a one-hot over SINGLE, DOUBLE, TRIPLE and
    AROMATIC bond types (all zero for any other type); the last two are the
    conjugation flag and the ring-membership flag.
    """
    bond_kinds = Chem.rdchem.BondType
    this_kind = bond.GetBondType()
    type_onehot = [
        1 if this_kind == kind else 0
        for kind in (bond_kinds.SINGLE, bond_kinds.DOUBLE, bond_kinds.TRIPLE, bond_kinds.AROMATIC)
    ]
    flags = [bond.GetIsConjugated() * 1.0, bond.IsInRing() * 1.0]
    return type_onehot + flags

def smiles_to_graph(smiles):
    """Build a PyG ``Data`` graph from a SMILES string.

    Nodes are atoms (features from ``get_atom_features``); every bond is
    stored as two directed edges carrying the same ``get_bond_features``
    vector.

    Returns:
        ``Data(x, edge_index, edge_attr)``, or ``None`` when RDKit cannot
        parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Node feature matrix: one row per atom.
    node_rows = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(node_rows, dtype=torch.double)

    # Each bond contributes the pair (begin -> end) followed by (end -> begin).
    pairs = []
    bond_rows = []
    for bond in mol.GetBonds():
        src, dst = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feats = get_bond_features(bond)
        pairs += [[src, dst], [dst, src]]
        bond_rows += [feats, feats]

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(bond_rows, dtype=torch.double)
    else:
        # No bonds at all (e.g. a single atom): keep correctly shaped empty tensors.
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 6), dtype=torch.double)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# Pre-compute graphs for all solvents
# Maps each solvent name (an index entry of smiles_lookup) to its graph, so the
# SMILES-to-graph conversion runs once here instead of on every model call.
SOLVENT_GRAPHS = {}
for solvent_name in smiles_lookup.index:
    smiles = smiles_lookup.loc[solvent_name, 'solvent smiles']  # Fixed column name
    graph = smiles_to_graph(smiles)
    # smiles_to_graph returns None for SMILES that RDKit cannot parse; such
    # solvents are simply left out of the dictionary.
    if graph is not None:
        SOLVENT_GRAPHS[solvent_name] = graph
        
print(f"Pre-computed {len(SOLVENT_GRAPHS)} solvent graphs")
print(f"Example graph: {list(SOLVENT_GRAPHS.values())[0]}")

SMILES lookup: (26, 1)
                                           solvent smiles
SOLVENT NAME                                             
Cyclohexane                                      C1CCCCC1
Ethyl Acetate                                   O=C(OCC)C
Acetic Acid                                       CC(=O)O
2-Methyltetrahydrofuran [2-MeTHF]              O1C(C)CCC1
1,1,1,3,3,3-Hexafluoropropan-2-ol  C(C(F)(F)F)(C(F)(F)F)O
Pre-computed 26 solvent graphs
Example graph: Data(x=[6, 22], edge_index=[2, 12], edge_attr=[12, 6])


In [9]:
# --- GAT-BASED GNN MODEL ---

class GATEncoder(nn.Module):
    """Graph attention encoder producing one vector per molecular graph.

    Node and edge features are first projected to ``hidden_dim`` with linear
    layers, then passed through three ``GATConv`` layers that all receive the
    embedded edge features. The layers form a plain sequential stack; no
    skip/residual connections are applied. The per-node outputs are pooled
    per graph with mean and max pooling and concatenated.

    Args:
        node_dim: length of the raw node feature vector.
        edge_dim: length of the raw edge feature vector.
        hidden_dim: embedding width and per-head output width.
        num_heads: number of attention heads in the first two GAT layers.
        dropout: dropout rate used inside the GAT layers and between them.
    """
    def __init__(self, node_dim=22, edge_dim=6, hidden_dim=64, num_heads=4, dropout=0.1):
        super().__init__()

        # Linear projections of raw node / edge features.
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)

        # Settings common to all three attention layers.
        shared = dict(dropout=dropout, edge_dim=hidden_dim)
        multi_head_width = hidden_dim * num_heads
        self.conv1 = GATConv(hidden_dim, hidden_dim, heads=num_heads, **shared)
        self.conv2 = GATConv(multi_head_width, hidden_dim, heads=num_heads, **shared)
        self.conv3 = GATConv(multi_head_width, hidden_dim, heads=1, concat=False, **shared)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return the pooled graph embeddings (mean pooling followed by max pooling)."""
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)

        # ELU + dropout after the first two layers; the last layer is left linear.
        h = self.dropout(F.elu(self.conv1(h, edge_index, e)))
        h = self.dropout(F.elu(self.conv2(h, edge_index, e)))
        h = self.conv3(h, edge_index, e)

        pooled = [global_mean_pool(h, batch), global_max_pool(h, batch)]
        return torch.cat(pooled, dim=1)  # 2 * hidden_dim


class GNNYieldPredictor(nn.Module):
    """Yield predictor combining a graph embedding with tabular features.

    Three branches are encoded separately - the molecular graph (via
    ``GATEncoder``), the process-condition vector and the Spange descriptor
    vector - then concatenated and mapped by an MLP to three outputs that are
    squashed with a sigmoid.

    Args:
        node_dim: raw node feature length passed to the graph encoder.
        edge_dim: raw edge feature length passed to the graph encoder.
        hidden_dim: graph encoder width (its output is ``2 * hidden_dim``).
        process_dim: length of the process-condition vector.
        spange_dim: length of the Spange descriptor vector.
    """
    def __init__(self, node_dim=22, edge_dim=6, hidden_dim=64, process_dim=5, spange_dim=13):
        super().__init__()

        def two_layer_encoder(in_features, width=32):
            # Linear -> ReLU -> Linear, used for both tabular branches.
            return nn.Sequential(
                nn.Linear(in_features, width),
                nn.ReLU(),
                nn.Linear(width, width)
            )

        self.gnn = GATEncoder(node_dim, edge_dim, hidden_dim)
        self.process_encoder = two_layer_encoder(process_dim)
        self.spange_encoder = two_layer_encoder(spange_dim)

        # Graph embedding (2 * hidden_dim) + process branch (32) + Spange branch (32).
        fused_width = 2 * hidden_dim + 32 + 32
        self.predictor = nn.Sequential(
            nn.Linear(fused_width, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 3)  # one output per target: Product 2, Product 3, SM
        )

    def forward(self, graph_batch, process_feats, spange_feats):
        """Return sigmoid-activated predictions with one row per input sample."""
        branches = [
            self.gnn(graph_batch.x, graph_batch.edge_index, graph_batch.edge_attr, graph_batch.batch),
            self.process_encoder(process_feats),
            self.spange_encoder(spange_feats),
        ]
        raw = self.predictor(torch.cat(branches, dim=1))
        return torch.sigmoid(raw)  # keeps every prediction strictly between 0 and 1

print("GNN model defined")

GNN model defined


In [10]:
# --- GNN MODEL WRAPPER FOR COMPETITION ---

class GNNModel(BaseModel):
    """GNN-based yield predictor with GAT.

    Wraps ``GNNYieldPredictor`` behind the ``train_model`` / ``predict``
    interface used by the cross-validation cells.

    Args:
        data: ``'full'`` selects the solvent-mixture layout (columns
            ``SOLVENT A NAME``, ``SOLVENT B NAME``, ``SolventB%``); any other
            value selects the single-solvent layout (``SOLVENT NAME``).
        hidden_dim: width of the graph encoder.
        lr: AdamW learning rate.
        epochs: number of full-batch training steps.
    """

    def __init__(self, data='single', hidden_dim=64, lr=1e-3, epochs=100):
        super().__init__()
        self.data_type = data
        self.mixed = (data == 'full')
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.epochs = epochs
        self.device = device

        # Feature dimensions
        self.process_dim = 6 if self.mixed else 5  # RT, Temp, inv_temp, log_time, interaction, [pct]
        self.spange_dim = SPANGE_DF.shape[1]

        # Scalers are fitted in train_model and reused in predict.
        self.process_scaler = StandardScaler()
        self.spange_scaler = StandardScaler()

        # Created in train_model; None means "not trained yet".
        self.model = None

    def _get_solvent_graph(self, solvent_name):
        """Get pre-computed graph for a solvent.

        Solvents without an entry in ``SOLVENT_GRAPHS`` get a placeholder
        graph with a single all-zero node and no edges.
        """
        if solvent_name in SOLVENT_GRAPHS:
            return SOLVENT_GRAPHS[solvent_name]
        return Data(
            x=torch.zeros((1, 22), dtype=torch.double),
            edge_index=torch.zeros((2, 0), dtype=torch.long),
            edge_attr=torch.zeros((0, 6), dtype=torch.double)
        )

    def _build_features(self, X):
        """Build process and Spange features.

        Process features are residence time, temperature, 1000 / (T + 273.15),
        log(residence time + 1e-6) and the product of the last two; in mixture
        mode ``SolventB%`` is appended. Spange features are looked up per
        solvent; in mixture mode they are the linear blend
        ``A * (1 - pct) + B * pct``.

        Returns:
            Tuple ``(process_feats, spange_feats)`` of 2-D NumPy arrays.
        """
        rt = X['Residence Time'].values.astype(np.float64).reshape(-1, 1)
        temp = X['Temperature'].values.astype(np.float64).reshape(-1, 1)
        temp_k = temp + 273.15  # presumably Celsius -> Kelvin; confirm units of the column
        inv_temp = 1000.0 / temp_k
        log_time = np.log(rt + 1e-6)  # small offset avoids log(0)
        interaction = inv_temp * log_time

        if self.mixed:
            # Cast like the other process columns so the stacked array is always float64.
            pct = X['SolventB%'].values.astype(np.float64).reshape(-1, 1)
            process_feats = np.hstack([rt, temp, inv_temp, log_time, interaction, pct])

            # Interpolated Spange features
            # NOTE(review): the blend assumes SolventB% is a 0-1 fraction - confirm.
            A = SPANGE_DF.loc[X['SOLVENT A NAME']].values
            B = SPANGE_DF.loc[X['SOLVENT B NAME']].values
            spange_feats = A * (1 - pct) + B * pct
        else:
            process_feats = np.hstack([rt, temp, inv_temp, log_time, interaction])
            spange_feats = SPANGE_DF.loc[X['SOLVENT NAME']].values

        return process_feats, spange_feats

    def _get_graphs_batch(self, X):
        """Get batch of molecular graphs, one graph per row of ``X``.

        In mixture mode only the solvent A graph is used; solvent B enters
        the model solely through the blended Spange features and the
        ``SolventB%`` process feature.
        """
        name_column = 'SOLVENT A NAME' if self.mixed else 'SOLVENT NAME'
        graphs = [self._get_solvent_graph(name) for name in X[name_column]]
        return Batch.from_data_list(graphs)

    def train_model(self, X_train, y_train):
        """Fit scalers and train a fresh network on the given data.

        Training is full-batch: every epoch runs one forward/backward pass
        over the whole training set with an L1 (MAE) loss, AdamW and a cosine
        learning-rate schedule spanning ``epochs`` steps.

        Returns:
            ``self``.
        """
        # Build and scale tabular features.
        process_feats, spange_feats = self._build_features(X_train)
        process_feats = self.process_scaler.fit_transform(process_feats)
        spange_feats = self.spange_scaler.fit_transform(spange_feats)

        # Get graphs
        graph_batch = self._get_graphs_batch(X_train)

        # Convert to tensors
        process_tensor = torch.tensor(process_feats, dtype=torch.double).to(self.device)
        spange_tensor = torch.tensor(spange_feats, dtype=torch.double).to(self.device)
        y_tensor = torch.tensor(y_train.values, dtype=torch.double).to(self.device)
        graph_batch = graph_batch.to(self.device)

        # Initialize model
        self.model = GNNYieldPredictor(
            node_dim=22, edge_dim=6, hidden_dim=self.hidden_dim,
            process_dim=self.process_dim, spange_dim=self.spange_dim
        ).double().to(self.device)

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)

        # Training loop
        self.model.train()
        for _ in range(self.epochs):
            optimizer.zero_grad()

            # Forward pass
            preds = self.model(graph_batch, process_tensor, spange_tensor)

            # MAE loss
            loss = F.l1_loss(preds, y_tensor)

            # Backward pass
            loss.backward()
            optimizer.step()
            scheduler.step()

        return self

    def predict(self, X):
        """Predict the three targets for every row of ``X``.

        Returns:
            CPU tensor with one row per input row and one column per target.

        Raises:
            RuntimeError: if called before ``train_model``.
        """
        if self.model is None:
            raise RuntimeError("GNNModel.predict() called before train_model()")

        # Build features and apply the scalers fitted during training.
        process_feats, spange_feats = self._build_features(X)
        process_feats = self.process_scaler.transform(process_feats)
        spange_feats = self.spange_scaler.transform(spange_feats)

        # Get graphs
        graph_batch = self._get_graphs_batch(X)

        # Convert to tensors
        process_tensor = torch.tensor(process_feats, dtype=torch.double).to(self.device)
        spange_tensor = torch.tensor(spange_feats, dtype=torch.double).to(self.device)
        graph_batch = graph_batch.to(self.device)

        # Predict (eval mode disables dropout; no gradients are tracked)
        self.model.eval()
        with torch.no_grad():
            preds = self.model(graph_batch, process_tensor, spange_tensor)

        return preds.cpu()

print("GNNModel wrapper defined")

GNNModel wrapper defined


In [11]:
# Quick test on a few folds
# Smoke test only: trains with 50 epochs on the first three
# leave-one-solvent-out folds and prints the per-fold and mean MAE.
print("Testing GNN model...")
X_test, Y_test = load_data("single_solvent")

errors = []
for i, ((train_X, train_Y), (test_X, test_Y)) in enumerate(generate_leave_one_out_splits(X_test, Y_test)):
    if i >= 3: break  # stop after three folds
    
    model = GNNModel(data='single', hidden_dim=64, lr=1e-3, epochs=50)
    model.train_model(train_X, train_Y)
    preds = model.predict(test_X).numpy()
    # Mean absolute error over every element of the prediction array.
    mae = np.mean(np.abs(preds - test_Y.values))
    errors.append(mae)
    # Every row of a test fold shares the same solvent, so the first row names the fold.
    solvent = test_X['SOLVENT NAME'].iloc[0]
    print(f"Fold {i} ({solvent}): MAE = {mae:.4f}")

print(f"\nQuick test MAE: {np.mean(errors):.4f}")

Testing GNN model...


Fold 0 (1,1,1,3,3,3-Hexafluoropropan-2-ol): MAE = 0.2740


Fold 1 (2,2,2-Trifluoroethanol): MAE = 0.1711


Fold 2 (2-Methyltetrahydrofuran [2-MeTHF]): MAE = 0.1719

Quick test MAE: 0.2057


In [12]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModel(data='single', hidden_dim=64, lr=1e-3, epochs=100) # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

0it [00:00, ?it/s]

1it [00:00,  1.55it/s]

2it [00:01,  1.60it/s]

3it [00:01,  1.61it/s]

4it [00:02,  1.63it/s]

5it [00:03,  1.63it/s]

6it [00:03,  1.63it/s]

7it [00:04,  1.63it/s]

8it [00:04,  1.62it/s]

9it [00:05,  1.62it/s]

10it [00:06,  1.63it/s]

11it [00:06,  1.63it/s]

12it [00:07,  1.63it/s]

13it [00:08,  1.63it/s]

14it [00:08,  1.62it/s]

15it [00:09,  1.62it/s]

16it [00:09,  1.63it/s]

17it [00:10,  1.59it/s]

18it [00:11,  1.60it/s]

19it [00:11,  1.60it/s]

20it [00:12,  1.61it/s]

21it [00:12,  1.62it/s]

22it [00:13,  1.62it/s]

23it [00:14,  1.63it/s]

24it [00:14,  1.63it/s]

24it [00:14,  1.62it/s]




In [13]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = GNNModel(data='full', hidden_dim=64, lr=1e-3, epochs=100) # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

0it [00:00, ?it/s]

1it [00:00,  1.46it/s]

2it [00:01,  1.52it/s]

3it [00:01,  1.52it/s]

4it [00:02,  1.52it/s]

5it [00:03,  1.53it/s]

6it [00:03,  1.53it/s]

7it [00:04,  1.54it/s]

8it [00:05,  1.53it/s]

9it [00:05,  1.52it/s]

10it [00:06,  1.52it/s]

11it [00:07,  1.53it/s]

12it [00:07,  1.53it/s]

13it [00:08,  1.52it/s]

13it [00:08,  1.52it/s]




In [14]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################