# Advanced GNN with GAT and Proper Mixture Handling

**Hypothesis**: Our previous GNN (CV 0.01408) was 3.6x worse than the benchmark (CV 0.0039). Key improvements:
1. Use GATConv (Graph Attention) instead of GCNConv
2. Properly handle mixtures: encode BOTH solvent graphs, combine with attention
3. Add edge features (bond types)
4. Use attention-based pooling (GlobalAttention)
5. Increase model capacity and train longer

**Target**: Achieve CV < 0.006 (would predict LB ≈ 0.079)

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
import tqdm
import warnings
warnings.filterwarnings('ignore')

# PyTorch Geometric imports
from torch_geometric.nn import GATConv, GlobalAttention, global_mean_pool
from torch_geometric.utils import from_smiles
from torch_geometric.data import Data, Batch

np.random.seed(42)
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.set_default_dtype(torch.float32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# Data loading functions
DATA_PATH = '/home/data'

INPUT_LABELS_NUMERIC = ["Residence Time", "Temperature"]
INPUT_LABELS_SINGLE_SOLVENT = ["Residence Time", "Temperature", "SOLVENT NAME"]
INPUT_LABELS_FULL_SOLVENT = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]

def load_data(name="full"):
    if name == "full":
        df = pd.read_csv(f'{DATA_PATH}/catechol_full_data_yields.csv')
        X = df[INPUT_LABELS_FULL_SOLVENT]
    else:
        df = pd.read_csv(f'{DATA_PATH}/catechol_single_solvent_yields.csv')
        X = df[INPUT_LABELS_SINGLE_SOLVENT]
    Y = df[["Product 2", "Product 3", "SM"]]
    return X, Y

def generate_leave_one_out_splits(X, Y):
    for solvent in sorted(X["SOLVENT NAME"].unique()):
        mask = X["SOLVENT NAME"] != solvent
        yield (X[mask], Y[mask]), (X[~mask], Y[~mask])

def generate_leave_one_ramp_out_splits(X, Y):
    ramps = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
    for _, row in ramps.iterrows():
        mask = ~((X["SOLVENT A NAME"] == row["SOLVENT A NAME"]) & (X["SOLVENT B NAME"] == row["SOLVENT B NAME"]))
        yield (X[mask], Y[mask]), (X[~mask], Y[~mask])

print('Data loading functions defined')

Data loading functions defined


In [3]:
# Load SMILES lookup and pre-compute graphs
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
print(f'SMILES lookup: {SMILES_DF.shape}')

# Pre-compute graph data for all solvents
SOLVENT_GRAPHS = {}
for solvent in SMILES_DF.index:
    smiles = SMILES_DF.loc[solvent, 'solvent smiles']
    # Handle mixture solvents (e.g., "Water.Acetonitrile")
    if '.' in smiles:
        parts = smiles.split('.')
        graphs = [from_smiles(s) for s in parts]
        SOLVENT_GRAPHS[solvent] = graphs
    else:
        SOLVENT_GRAPHS[solvent] = from_smiles(smiles)

print(f'Pre-computed graphs for {len(SOLVENT_GRAPHS)} solvents')

SMILES lookup: (26, 1)
Pre-computed graphs for 26 solvents


In [4]:
# Advanced GAT-based GNN Encoder
class GATEncoder(nn.Module):
    """Graph Attention Network encoder for molecular graphs."""
    def __init__(self, node_features=9, hidden_dim=64, output_dim=64, heads=4, dropout=0.2):
        super().__init__()
        # Multi-head attention layers
        self.conv1 = GATConv(node_features, hidden_dim, heads=heads, concat=True, dropout=dropout)
        self.conv2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, concat=True, dropout=dropout)
        self.conv3 = GATConv(hidden_dim * heads, output_dim, heads=1, concat=False, dropout=dropout)
        
        # Batch normalization
        self.bn1 = nn.BatchNorm1d(hidden_dim * heads)
        self.bn2 = nn.BatchNorm1d(hidden_dim * heads)
        
        # Attention-based pooling
        gate_nn = nn.Sequential(
            nn.Linear(output_dim, output_dim // 2),
            nn.ReLU(),
            nn.Linear(output_dim // 2, 1)
        )
        self.pool = GlobalAttention(gate_nn)
        
    def forward(self, x, edge_index, batch):
        # First GAT layer
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        
        # Second GAT layer
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        
        # Third GAT layer
        x = self.conv3(x, edge_index)
        x = F.elu(x)
        
        # Attention-based pooling
        x = self.pool(x, batch)
        return x

print('GAT Encoder defined')

GAT Encoder defined


In [5]:
# Advanced GNN Model with proper mixture handling
class AdvancedGNNModel(nn.Module):
    """Advanced GNN model with GAT encoder and proper mixture handling."""
    def __init__(self, node_features=9, hidden_dim=64, graph_dim=64, output_dim=3, dropout=0.2):
        super().__init__()
        self.gnn = GATEncoder(node_features, hidden_dim, graph_dim, heads=4, dropout=dropout)
        
        # Mixture attention: learn to combine two solvent embeddings
        self.mixture_attention = nn.Sequential(
            nn.Linear(graph_dim * 2 + 1, graph_dim),  # +1 for mixture percentage
            nn.ReLU(),
            nn.Linear(graph_dim, 2),
            nn.Softmax(dim=1)
        )
        
        # Condition encoder (temperature, time, kinetics)
        self.condition_encoder = nn.Sequential(
            nn.Linear(5, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )
        
        # Final prediction head
        self.fc1 = nn.Linear(graph_dim + 32, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 2)
        
    def forward(self, graph_data, conditions, graph_data_b=None, mixture_pct=None):
        # Encode primary solvent graph
        graph_emb_a = self.gnn(graph_data.x, graph_data.edge_index, graph_data.batch)
        
        if graph_data_b is not None and mixture_pct is not None:
            # Encode secondary solvent graph
            graph_emb_b = self.gnn(graph_data_b.x, graph_data_b.edge_index, graph_data_b.batch)
            
            # Compute attention weights for mixture
            mixture_input = torch.cat([graph_emb_a, graph_emb_b, mixture_pct], dim=1)
            attention_weights = self.mixture_attention(mixture_input)
            
            # Weighted combination of embeddings
            graph_emb = attention_weights[:, 0:1] * graph_emb_a + attention_weights[:, 1:2] * graph_emb_b
        else:
            graph_emb = graph_emb_a
        
        # Encode conditions
        cond_emb = self.condition_encoder(conditions)
        
        # Combine and predict
        x = torch.cat([graph_emb, cond_emb], dim=1)
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc3(x))
        return x

print('Advanced GNN Model defined')

Advanced GNN Model defined


In [6]:
# Advanced GNN Model Wrapper
class AdvancedGNNWrapper:
    def __init__(self, data='single', epochs=500, lr=0.001, hidden_dim=64, graph_dim=64, dropout=0.2):
        self.data = data
        self.epochs = epochs
        self.lr = lr
        self.hidden_dim = hidden_dim
        self.graph_dim = graph_dim
        self.dropout = dropout
        self.model = None
        self.scaler = StandardScaler()
        
    def _get_conditions(self, X):
        """Extract and transform condition features."""
        temp_c = X["Temperature"].values.reshape(-1, 1)
        time_m = X["Residence Time"].values.reshape(-1, 1)
        temp_k = temp_c + 273.15
        inv_temp = 1000.0 / temp_k
        log_time = np.log(time_m + 1e-6)
        interaction = inv_temp * log_time
        return np.hstack([temp_c, time_m, inv_temp, log_time, interaction])
    
    def _get_graph(self, solvent_name):
        """Get graph for a solvent, handling mixtures."""
        g = SOLVENT_GRAPHS.get(solvent_name)
        if g is None or isinstance(g, list):
            # Fallback for unknown or mixture solvents
            g = from_smiles('C')  # Methane as fallback
        return Data(x=g.x.float(), edge_index=g.edge_index)
    
    def _get_graph_batch(self, X):
        """Create batches of graphs."""
        if self.data == 'single':
            graphs = [self._get_graph(row["SOLVENT NAME"]) for _, row in X.iterrows()]
            return Batch.from_data_list(graphs), None, None
        else:
            graphs_a = [self._get_graph(row["SOLVENT A NAME"]) for _, row in X.iterrows()]
            graphs_b = [self._get_graph(row["SOLVENT B NAME"]) for _, row in X.iterrows()]
            mixture_pct = X["SolventB%"].values.reshape(-1, 1)
            return Batch.from_data_list(graphs_a), Batch.from_data_list(graphs_b), mixture_pct
    
    def train_model(self, X, Y):
        """Train the advanced GNN model."""
        # Get conditions
        conditions = self._get_conditions(X)
        conditions_scaled = self.scaler.fit_transform(conditions)
        
        # Get graph batches
        graph_batch_a, graph_batch_b, mixture_pct = self._get_graph_batch(X)
        
        # Convert to tensors
        conditions_tensor = torch.tensor(conditions_scaled, dtype=torch.float32).to(device)
        y_tensor = torch.tensor(Y.values, dtype=torch.float32).to(device)
        graph_batch_a = graph_batch_a.to(device)
        if graph_batch_b is not None:
            graph_batch_b = graph_batch_b.to(device)
            mixture_pct_tensor = torch.tensor(mixture_pct, dtype=torch.float32).to(device)
        else:
            mixture_pct_tensor = None
        
        # Initialize model
        self.model = AdvancedGNNModel(
            node_features=9,
            hidden_dim=self.hidden_dim,
            graph_dim=self.graph_dim,
            output_dim=3,
            dropout=self.dropout
        ).to(device)
        
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs)
        criterion = nn.MSELoss()
        
        # Training loop
        self.model.train()
        best_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(self.epochs):
            optimizer.zero_grad()
            
            # Forward pass
            predictions = self.model(graph_batch_a, conditions_tensor, graph_batch_b, mixture_pct_tensor)
            loss = criterion(predictions, y_tensor)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
            # Early stopping
            if loss.item() < best_loss:
                best_loss = loss.item()
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 100:
                    break
        
        return self
    
    def predict(self, X):
        """Make predictions."""
        self.model.eval()
        
        # Get conditions
        conditions = self._get_conditions(X)
        conditions_scaled = self.scaler.transform(conditions)
        
        # Get graph batches
        graph_batch_a, graph_batch_b, mixture_pct = self._get_graph_batch(X)
        
        # Convert to tensors
        conditions_tensor = torch.tensor(conditions_scaled, dtype=torch.float32).to(device)
        graph_batch_a = graph_batch_a.to(device)
        if graph_batch_b is not None:
            graph_batch_b = graph_batch_b.to(device)
            mixture_pct_tensor = torch.tensor(mixture_pct, dtype=torch.float32).to(device)
        else:
            mixture_pct_tensor = None
        
        with torch.no_grad():
            predictions = self.model(graph_batch_a, conditions_tensor, graph_batch_b, mixture_pct_tensor)
        
        # TTA for mixtures (flip A and B)
        if self.data == 'full':
            # Swap A and B
            graphs_b = [self._get_graph(row["SOLVENT A NAME"]) for _, row in X.iterrows()]
            graphs_a = [self._get_graph(row["SOLVENT B NAME"]) for _, row in X.iterrows()]
            graph_batch_a_flip = Batch.from_data_list(graphs_a).to(device)
            graph_batch_b_flip = Batch.from_data_list(graphs_b).to(device)
            mixture_pct_flip = 1.0 - mixture_pct
            mixture_pct_flip_tensor = torch.tensor(mixture_pct_flip, dtype=torch.float32).to(device)
            
            with torch.no_grad():
                predictions_flip = self.model(graph_batch_a_flip, conditions_tensor, graph_batch_b_flip, mixture_pct_flip_tensor)
            predictions = (predictions + predictions_flip) / 2
        
        # Clip to [0, 1]
        predictions = torch.clamp(predictions, 0, 1)
        
        return predictions

print('Advanced GNN Wrapper defined')

Advanced GNN Wrapper defined


In [7]:
# Quick test
X_single, Y_single = load_data("single_solvent")
print(f'Single solvent data: X={X_single.shape}, Y={Y_single.shape}')

# Test on a small subset
X_test = X_single.iloc[:50]
Y_test = Y_single.iloc[:50]

model = AdvancedGNNWrapper(data='single', epochs=100)
model.train_model(X_test, Y_test)
preds = model.predict(X_test)
print(f'Test predictions shape: {preds.shape}')
print(f'Test predictions range: [{preds.min():.4f}, {preds.max():.4f}]')

Single solvent data: X=(656, 3), Y=(656, 3)


Test predictions shape: torch.Size([50, 3])
Test predictions range: [0.1308, 0.8841]


In [8]:
# Run CV on single solvent data
print('\n=== Single Solvent CV (Advanced GNN) ===')
X_single, Y_single = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X_single, Y_single)
all_predictions_single = []
all_actuals_single = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator), total=24):
    (train_X, train_Y), (test_X, test_Y) = split
    
    model = AdvancedGNNWrapper(data='single', epochs=500, lr=0.001, hidden_dim=64, graph_dim=64)
    model.train_model(train_X, train_Y)
    predictions = model.predict(test_X)
    
    all_predictions_single.append(predictions.cpu().numpy())
    all_actuals_single.append(test_Y.values)

preds_single = np.vstack(all_predictions_single)
actuals_single = np.vstack(all_actuals_single)
mse_single = np.mean((preds_single - actuals_single) ** 2)
print(f'Single Solvent MSE: {mse_single:.6f} (n={len(preds_single)})')


=== Single Solvent CV (Advanced GNN) ===


  0%|          | 0/24 [00:00<?, ?it/s]

  4%|▍         | 1/24 [00:02<01:01,  2.68s/it]

  8%|▊         | 2/24 [00:05<00:57,  2.62s/it]

 12%|█▎        | 3/24 [00:07<00:54,  2.60s/it]

 17%|█▋        | 4/24 [00:10<00:52,  2.61s/it]

 21%|██        | 5/24 [00:13<00:49,  2.61s/it]

 25%|██▌       | 6/24 [00:15<00:47,  2.61s/it]

 29%|██▉       | 7/24 [00:18<00:44,  2.62s/it]

 33%|███▎      | 8/24 [00:20<00:41,  2.62s/it]

 38%|███▊      | 9/24 [00:23<00:39,  2.62s/it]

 42%|████▏     | 10/24 [00:26<00:36,  2.63s/it]

 46%|████▌     | 11/24 [00:28<00:34,  2.62s/it]

 50%|█████     | 12/24 [00:31<00:31,  2.62s/it]

 54%|█████▍    | 13/24 [00:34<00:28,  2.62s/it]

 58%|█████▊    | 14/24 [00:36<00:26,  2.62s/it]

 62%|██████▎   | 15/24 [00:39<00:23,  2.61s/it]

 67%|██████▋   | 16/24 [00:41<00:20,  2.60s/it]

 71%|███████   | 17/24 [00:44<00:18,  2.61s/it]

 75%|███████▌  | 18/24 [00:46<00:15,  2.59s/it]

 79%|███████▉  | 19/24 [00:49<00:12,  2.60s/it]

 83%|████████▎ | 20/24 [00:52<00:10,  2.63s/it]

 88%|████████▊ | 21/24 [00:54<00:07,  2.62s/it]

 92%|█████████▏| 22/24 [00:57<00:05,  2.61s/it]

 96%|█████████▌| 23/24 [01:00<00:02,  2.62s/it]

100%|██████████| 24/24 [01:02<00:00,  2.63s/it]

100%|██████████| 24/24 [01:02<00:00,  2.62s/it]

Single Solvent MSE: 0.016986 (n=656)





In [9]:
# Run CV on full data
print('\n=== Full Data CV (Advanced GNN) ===')
X_full, Y_full = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X_full, Y_full)
all_predictions_full = []
all_actuals_full = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator), total=13):
    (train_X, train_Y), (test_X, test_Y) = split
    
    model = AdvancedGNNWrapper(data='full', epochs=500, lr=0.001, hidden_dim=64, graph_dim=64)
    model.train_model(train_X, train_Y)
    predictions = model.predict(test_X)
    
    all_predictions_full.append(predictions.cpu().numpy())
    all_actuals_full.append(test_Y.values)

preds_full = np.vstack(all_predictions_full)
actuals_full = np.vstack(all_actuals_full)
mse_full = np.mean((preds_full - actuals_full) ** 2)
print(f'Full Data MSE: {mse_full:.6f} (n={len(preds_full)})')


=== Full Data CV (Advanced GNN) ===


  0%|          | 0/13 [00:00<?, ?it/s]

  8%|▊         | 1/13 [00:04<00:57,  4.82s/it]

 15%|█▌        | 2/13 [00:09<00:52,  4.79s/it]

 23%|██▎       | 3/13 [00:14<00:47,  4.78s/it]

 31%|███       | 4/13 [00:19<00:43,  4.79s/it]

 38%|███▊      | 5/13 [00:23<00:38,  4.80s/it]

 46%|████▌     | 6/13 [00:28<00:33,  4.83s/it]

 54%|█████▍    | 7/13 [00:33<00:28,  4.82s/it]

 62%|██████▏   | 8/13 [00:38<00:24,  4.82s/it]

 69%|██████▉   | 9/13 [00:43<00:19,  4.82s/it]

 77%|███████▋  | 10/13 [00:48<00:14,  4.82s/it]

 85%|████████▍ | 11/13 [00:52<00:09,  4.81s/it]

 92%|█████████▏| 12/13 [00:57<00:04,  4.83s/it]

100%|██████████| 13/13 [01:02<00:00,  4.84s/it]

100%|██████████| 13/13 [01:02<00:00,  4.82s/it]

Full Data MSE: 0.036978 (n=1227)





In [None]:
# Calculate overall MSE
n_single = len(preds_single)
n_full = len(preds_full)
overall_mse = (mse_single * n_single + mse_full * n_full) / (n_single + n_full)

print(f'\n=== CV SCORE SUMMARY (Advanced GNN) ===')
print(f'Single Solvent MSE: {mse_single:.6f} (n={n_single})')
print(f'Full Data MSE: {mse_full:.6f} (n={n_full})')
print(f'Overall MSE: {overall_mse:.6f}')
print(f'\nBest CV (exp_032): 0.008194')
print(f'Previous GNN (exp_051): 0.01408')
print(f'GNN Benchmark: 0.0039')

if overall_mse < 0.008194:
    improvement = (0.008194 - overall_mse) / 0.008194 * 100
    print(f'\n✓ IMPROVEMENT: {improvement:.2f}% better than best CV!')
else:
    degradation = (overall_mse - 0.008194) / 0.008194 * 100
    print(f'\n✗ WORSE: {degradation:.2f}% worse than best CV')

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = AdvancedGNNWrapper(data='single')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = AdvancedGNNWrapper(data='full')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
# Final verification
print(f'\n=== FINAL CV SCORE ===')
print(f'Single Solvent MSE: {mse_single:.6f} (n={n_single})')
print(f'Full Data MSE: {mse_full:.6f} (n={n_full})')
print(f'Overall MSE: {overall_mse:.6f}')
print(f'\nBest CV (exp_032): 0.008194')
print(f'Previous GNN (exp_051): 0.01408')
print(f'GNN Benchmark: 0.0039')