# Dual-Encoder GNN with Fixed Mixture Handling

**CRITICAL BUG FIX**: The previous GNN (exp_081) used only Solvent A's graph for mixtures and completely ignored Solvent B. As a result, about 65% of the data (the 1227 full-data samples, out of 1883 samples in total) was modeled with incomplete information.

**This implementation fixes the bug by:**
1. Using a shared GNN encoder for both solvents
2. Encoding BOTH Solvent A and Solvent B graphs
3. Combining embeddings with weighted pooling: `(1-pct_b)*emb_a + pct_b*emb_b`
4. For single solvent data: pct_b=0, so only emb_a is used

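**Expected outcome**: If the mixture handling is fixed, the GNN should achieve a CV MSE closer to the tabular models (0.008–0.012). For reference, the benchmark achieved an MSE of 0.0039 with a properly implemented GNN.

**Observed outcome** (see the summary cell below): the overall CV MSE is 0.024454, which is 6.74% better than the previous GNN (0.026222) but still well above the best tabular model (0.008298).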

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed NumPy and PyTorch with the same value and ask cuDNN
# for deterministic kernels (autotuning off).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# PyTorch Geometric imports
from torch_geometric.utils import from_smiles
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader

print('PyTorch Geometric imports successful')

PyTorch Geometric imports successful


In [3]:
# Data loading functions
DATA_PATH = '/home/data'

def load_data(name="full"):
    """Read one of the two catechol yield tables and split it into inputs and targets.

    Args:
        name: "full" selects the mixture table (two solvent names plus
            "SolventB%"); any other value selects the single-solvent table.

    Returns:
        (X, Y): X holds the process/solvent columns for the chosen table,
        Y holds the three target columns "Product 2", "Product 3" and "SM".
    """
    target_cols = ["Product 2", "Product 3", "SM"]

    if name == "full":
        csv_path = f'{DATA_PATH}/catechol_full_data_yields.csv'
        input_cols = ["Residence Time", "Temperature", "SOLVENT A NAME", "SOLVENT B NAME", "SolventB%"]
    else:
        csv_path = f'{DATA_PATH}/catechol_single_solvent_yields.csv'
        input_cols = ["Residence Time", "Temperature", "SOLVENT NAME"]

    table = pd.read_csv(csv_path)
    return table[input_cols], table[target_cols]

def generate_leave_one_out_splits(X, Y):
    """Yield one train/test split per solvent, holding that solvent out.

    Solvents are visited in sorted order of the distinct "SOLVENT NAME"
    values. Each item is ((train_X, train_Y), (test_X, test_Y)), where the
    test part contains exactly the rows of the held-out solvent.
    """
    solvent_col = X["SOLVENT NAME"]
    for held_out in sorted(solvent_col.unique()):
        is_test = solvent_col == held_out
        is_train = ~is_test
        yield (X[is_train], Y[is_train]), (X[is_test], Y[is_test])

def generate_leave_one_ramp_out_splits(X, Y):
    """Yield one train/test split per distinct (Solvent A, Solvent B) pair.

    Pairs are visited in order of first appearance in X. Each item is
    ((train_X, train_Y), (test_X, test_Y)), where the test part contains
    exactly the rows whose two solvent names match the held-out pair.
    """
    unique_pairs = X[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates()
    for name_a, name_b in unique_pairs.itertuples(index=False, name=None):
        in_ramp = (X["SOLVENT A NAME"] == name_a) & (X["SOLVENT B NAME"] == name_b)
        outside_ramp = ~in_ramp
        yield (X[outside_ramp], Y[outside_ramp]), (X[in_ramp], Y[in_ramp])

print('Data loading functions defined')

Data loading functions defined


In [4]:
# Load the solvent-name -> SMILES table (solvent names form the index)
SMILES_DF = pd.read_csv(f'{DATA_PATH}/smiles_lookup.csv', index_col=0)
print(f'SMILES lookup shape: {SMILES_DF.shape}')

# Convert every SMILES string to a graph once, up front, so that datasets can
# simply look graphs up by solvent name. A solvent whose conversion fails is
# reported and left out of the dictionary.
print('\nPre-computing molecular graphs...')
SOLVENT_GRAPHS = {}
for solvent_name, smiles in SMILES_DF['solvent smiles'].items():
    try:
        graph = from_smiles(smiles)
        SOLVENT_GRAPHS[solvent_name] = graph
        print(f'  {solvent_name}: {graph.x.shape[0]} atoms')
    except Exception as e:
        print(f'  ERROR for {solvent_name}: {e}')

print(f'\nTotal solvents with graphs: {len(SOLVENT_GRAPHS)}')

SMILES lookup shape: (26, 1)

Pre-computing molecular graphs...
  Cyclohexane: 6 atoms
  Ethyl Acetate: 6 atoms
  Acetic Acid: 4 atoms
  2-Methyltetrahydrofuran [2-MeTHF]: 6 atoms
  1,1,1,3,3,3-Hexafluoropropan-2-ol: 10 atoms
  IPA [Propan-2-ol]: 4 atoms
  Ethanol: 3 atoms
  Methanol: 2 atoms
  Ethylene Glycol [1,2-Ethanediol]: 4 atoms
  Acetonitrile: 3 atoms
  Water: 1 atoms
  Diethyl Ether [Ether]: 5 atoms
  MTBE [tert-Butylmethylether]: 6 atoms
  Dimethyl Carbonate: 6 atoms
  tert-Butanol [2-Methylpropan-2-ol]: 5 atoms
  DMA [N,N-Dimethylacetamide]: 6 atoms
  2,2,2-Trifluoroethanol: 6 atoms
  Dihydrolevoglucosenone (Cyrene): 9 atoms
  Decanol: 11 atoms
  Butanone [MEK]: 5 atoms
  Ethyl Lactate: 8 atoms
  Methyl Propionate: 6 atoms
  THF [Tetrahydrofuran]: 5 atoms
  Water.Acetonitrile: 4 atoms
  Acetonitrile.Acetic Acid: 7 atoms
  Water.2,2,2-Trifluoroethanol: 7 atoms

Total solvents with graphs: 26


In [5]:
# Dual-Encoder GNN Model - FIXES THE MIXTURE BUG
class DualGNN(nn.Module):
    """GNN that handles both single solvents and two-solvent mixtures.

    One shared three-layer GCN encoder embeds the Solvent A graph and the
    Solvent B graph. The two embeddings are blended as
    ``(1 - pct_b) * emb_a + pct_b * emb_b``, concatenated with the process
    features, and passed through a two-layer MLP with a sigmoid output.

    Args:
        in_channels: Node-feature width of the input graphs (the notebook
            instantiates the model with 9 for graphs built by `from_smiles`).
        hidden_dim: Width of every GCN layer and of the MLP hidden layer.
        out_dim: Number of predicted targets.
        process_dim: Width of the process-feature vector concatenated to the
            graph embedding (default 5, matching the previous hard-coded value).
        dropout: Dropout probability used in the encoder and the MLP head
            (default 0.2, matching the previous hard-coded value).
    """
    def __init__(self, in_channels=9, hidden_dim=64, out_dim=3, process_dim=5, dropout=0.2):
        super().__init__()
        # Shared GNN encoder for both solvents.
        # NOTE: layers are created in the same order as before so that a
        # seeded run draws identical initial weights.
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.bn3 = nn.BatchNorm1d(hidden_dim)

        # MLP head: graph_emb (hidden_dim) + process_feats (process_dim)
        self.fc1 = nn.Linear(hidden_dim + process_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

    def encode_graph(self, data):
        """Encode a batch of molecular graphs into one embedding per graph.

        Args:
            data: Batched graph object exposing `x`, `edge_index` and `batch`.

        Returns:
            Tensor with one row of width `hidden_dim` per graph (mean over the
            node embeddings of each graph).
        """
        x, edge_index, batch = data.x.float(), data.edge_index, data.batch

        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        last = len(stages) - 1
        for depth, (conv, bn) in enumerate(stages):
            x = F.relu(bn(conv(x, edge_index)))
            # Dropout follows every stage except the last one, so the readout
            # pools un-dropped activations.
            if depth != last:
                x = self.dropout(x)

        # Graph-level readout
        return global_mean_pool(x, batch)

    def forward(self, graph_a, graph_b, pct_b, process_features):
        """Predict the targets for a batch of (solvent A, solvent B, conditions).

        Args:
            graph_a: Batched graph for Solvent A.
            graph_b: Batched graph for Solvent B (same as A for single solvent).
            pct_b: Fraction of Solvent B per sample (0 for single solvent);
                accepted with shape [batch_size] or [batch_size, 1].
            process_features: Tensor of shape [batch_size, process_dim].

        Returns:
            Tensor of shape [batch_size, out_dim] with values in (0, 1).
        """
        # Encode BOTH solvents with the shared encoder
        emb_a = self.encode_graph(graph_a)
        emb_b = self.encode_graph(graph_b)

        # Weighted combination based on mixture percentage. reshape(-1, 1)
        # gives a [batch_size, 1] column for both accepted input shapes, so the
        # weight always broadcasts across the embedding dimension only.
        weight_b = pct_b.reshape(-1, 1).to(emb_a.dtype)
        mixture_emb = (1 - weight_b) * emb_a + weight_b * emb_b

        # Concatenate with process features
        x = torch.cat([mixture_emb, process_features], dim=1)

        # MLP head
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))  # Yields are in [0, 1]

        return x

print('DualGNN defined')

DualGNN defined


In [6]:
# Custom Dataset for proper batching
class SolventDataset:
    """Map-style dataset turning one table row into model inputs.

    Each item is ``(graph_a, graph_b, pct_b, process_feats)`` followed by the
    target tensor ``y`` when labels were supplied. Graphs are cloned from the
    module-level ``SOLVENT_GRAPHS`` dictionary.

    Args:
        X: Feature table; needs "Temperature", "Residence Time" and either
            "SOLVENT NAME" (data_type='single') or "SOLVENT A NAME",
            "SOLVENT B NAME" and "SolventB%" (any other data_type).
        Y: Target table aligned with X row by row, or None for inference.
        data_type: 'single' for single-solvent rows, anything else for mixtures.
    """
    def __init__(self, X, Y, data_type='single'):
        self.data_type = data_type
        # Positional indexing is used below, so drop whatever index the CV
        # split left behind.
        self.X = X.reset_index(drop=True)
        self.Y = None if Y is None else Y.reset_index(drop=True)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        record = self.X.iloc[idx]

        # Process features: T, RT, 1000/(T + 273.15), ln(RT + 1e-6) and the
        # product of the last two.
        temperature = record['Temperature']
        residence_time = record['Residence Time']
        inv_T = 1000.0 / (temperature + 273.15)
        ln_RT = np.log(residence_time + 1e-6)
        process_feats = torch.tensor(
            [temperature, residence_time, inv_T, ln_RT, inv_T * ln_RT],
            dtype=torch.float,
        )

        is_single = self.data_type == 'single'
        if is_single:
            # A single solvent is presented as a "mixture" of itself.
            name_a = name_b = record['SOLVENT NAME']
        else:
            name_a, name_b = record['SOLVENT A NAME'], record['SOLVENT B NAME']

        graph_a = SOLVENT_GRAPHS[name_a].clone()
        graph_b = SOLVENT_GRAPHS[name_b].clone()
        fraction_b = 0.0 if is_single else record['SolventB%']
        pct_b = torch.tensor(fraction_b, dtype=torch.float)

        sample = (graph_a, graph_b, pct_b, process_feats)
        if self.Y is None:
            return sample
        y = torch.tensor(self.Y.iloc[idx].values, dtype=torch.float)
        return sample + (y,)

def collate_fn(batch):
    """Collate dataset items into batched graphs and stacked tensors.

    Items are either 5-tuples (graph_a, graph_b, pct_b, process_feats, y) or
    4-tuples without the label. The first item decides which layout is used.
    The result has the same layout, with graphs merged through
    `Batch.from_data_list` and tensors combined with `torch.stack`.
    """
    columns = tuple(zip(*batch))
    if len(batch[0]) == 5:  # With labels
        graphs_a, graphs_b, pct_bs, feats, ys = columns
    else:  # Without labels
        graphs_a, graphs_b, pct_bs, feats = columns
        ys = None

    collated = (
        Batch.from_data_list(graphs_a),
        Batch.from_data_list(graphs_b),
        torch.stack(pct_bs),
        torch.stack(feats),
    )
    if ys is None:
        return collated
    return collated + (torch.stack(ys),)

print('SolventDataset and collate_fn defined')

SolventDataset and collate_fn defined


In [7]:
# Model Wrapper
class DualGNNWrapper:
    """Train/predict wrapper around `DualGNN` with the interface the CV loops expect.

    Args:
        data: 'single' for the single-solvent table, anything else for the
            mixture table (forwarded to `SolventDataset` as `data_type`).
        hidden_dim: Hidden width passed to `DualGNN`.
        lr: Adam learning rate.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size for both training and prediction.

    Attributes set by `train_model`:
        model: The fitted `DualGNN` (None until `train_model` is called).
        train_mean: Per-target mean of the training labels.
        loss_history: Mean training loss of every epoch.
    """
    def __init__(self, data='single', hidden_dim=64, lr=0.001, epochs=100, batch_size=32):
        self.data = data
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.model = None
        self.train_mean = None
        self.loss_history = []

    @staticmethod
    def _to_device(items):
        """Move every batched graph / tensor in `items` to the global `device`."""
        return tuple(item.to(device) for item in items)

    def train_model(self, X, Y):
        """Fit a fresh `DualGNN` on (X, Y) with Adam and an MSE loss.

        A new model is created on every call, so repeated calls do not
        continue training from the previous fit.
        """
        self.train_mean = Y.mean().values
        self.loss_history = []

        # Create dataset and dataloader
        dataset = SolventDataset(X, Y, data_type=self.data)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn
        )

        # Create model
        self.model = DualGNN(in_channels=9, hidden_dim=self.hidden_dim, out_dim=3).to(device)

        # Training setup
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.MSELoss()

        # Training loop
        self.model.train()
        for epoch in range(self.epochs):
            total_loss = 0.0
            n_batches = 0
            for batch in loader:
                batch_a, batch_b, pct_b, process_feats, y = batch

                # The encoder applies BatchNorm1d over node rows, and in
                # training mode BatchNorm1d raises when it receives a single
                # row. That can only happen for a trailing batch holding one
                # sample whose graph has one node (the graph table contains
                # such a solvent). Skip that batch instead of aborting the fit.
                if min(batch_a.x.size(0), batch_b.x.size(0)) < 2:
                    continue

                batch_a, batch_b, pct_b, process_feats, y = self._to_device(
                    (batch_a, batch_b, pct_b, process_feats, y)
                )

                optimizer.zero_grad()
                out = self.model(batch_a, batch_b, pct_b, process_feats)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                n_batches += 1

            # Keep the per-epoch mean loss so convergence can be inspected.
            self.loss_history.append(total_loss / max(n_batches, 1))

    def predict(self, X):
        """Predict the targets for every row of X.

        Returns:
            CPU tensor of shape [len(X), out_dim]; an empty [0, out_dim]
            tensor when X has no rows.

        Raises:
            RuntimeError: If called before `train_model`.
        """
        if self.model is None:
            raise RuntimeError("DualGNNWrapper.predict() called before train_model()")

        dataset = SolventDataset(X, None, data_type=self.data)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size, shuffle=False, collate_fn=collate_fn
        )

        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in loader:
                batch_a, batch_b, pct_b, process_feats = self._to_device(batch)
                out = self.model(batch_a, batch_b, pct_b, process_feats)
                all_preds.append(out.cpu())

        if not all_preds:
            # torch.cat rejects an empty list; return a well-shaped empty result.
            return torch.empty((0, self.model.fc2.out_features))
        return torch.cat(all_preds, dim=0)

print('DualGNNWrapper defined')

DualGNNWrapper defined


In [8]:
# Smoke test: train briefly on the first rows of the single-solvent table and
# check that predictions come back with the expected shape.
print('Testing DualGNN model...')
X_single, Y_single = load_data('single_solvent')
print(f'Single solvent data: X={X_single.shape}, Y={Y_single.shape}')

# Test on a small subset
smoke_train_X = X_single.iloc[:100]
smoke_train_Y = Y_single.iloc[:100]
smoke_eval_X = X_single.iloc[:10]

test_model = DualGNNWrapper(data='single', epochs=5, batch_size=32)
test_model.train_model(smoke_train_X, smoke_train_Y)
test_preds = test_model.predict(smoke_eval_X)
print(f'Test predictions shape: {test_preds.shape}')
print(f'Test predictions (first 3):\n{test_preds[:3]}')

Testing DualGNN model...
Single solvent data: X=(656, 3), Y=(656, 3)


Test predictions shape: torch.Size([10, 3])
Test predictions (first 3):
tensor([[0.8865, 0.0260, 0.4064],
        [0.8678, 0.0271, 0.3886],
        [0.8394, 0.0282, 0.3650]])


In [9]:
# Cross-validation for single solvent data
print("="*60)
print("Cross-validation: Single Solvent Data (Leave-One-Out)")
print("="*60)

X_single, Y_single = load_data("single_solvent")
print(f"Single solvent data: X={X_single.shape}, Y={Y_single.shape}")

# generate_leave_one_out_splits yields one fold per distinct solvent, so take
# the progress-bar total from the data instead of hard-coding it.
n_folds_single = len(X_single["SOLVENT NAME"].unique())

all_mse_single = []
for (train_X, train_Y), (test_X, test_Y) in tqdm.tqdm(
    generate_leave_one_out_splits(X_single, Y_single), total=n_folds_single
):
    # A fresh model per fold: nothing learned on one solvent leaks into the next.
    model = DualGNNWrapper(data='single', epochs=100, batch_size=32, lr=0.001)
    model.train_model(train_X, train_Y)
    preds = model.predict(test_X).numpy()
    # Fold score: MSE averaged over all held-out rows and all three targets.
    mse = np.mean((preds - test_Y.values) ** 2)
    all_mse_single.append(mse)

# Unweighted mean over folds (every solvent counts equally).
mse_single = np.mean(all_mse_single)
print(f"\nSingle Solvent MSE: {mse_single:.6f} (+/- {np.std(all_mse_single):.6f})")

Cross-validation: Single Solvent Data (Leave-One-Out)
Single solvent data: X=(656, 3), Y=(656, 3)


  0%|          | 0/24 [00:00<?, ?it/s]

  4%|▍         | 1/24 [00:20<07:54, 20.62s/it]

  8%|▊         | 2/24 [00:41<07:33, 20.62s/it]

 12%|█▎        | 3/24 [01:00<06:58, 19.95s/it]

 17%|█▋        | 4/24 [01:19<06:34, 19.74s/it]

 21%|██        | 5/24 [01:40<06:20, 20.05s/it]

 25%|██▌       | 6/24 [02:00<06:04, 20.23s/it]

 29%|██▉       | 7/24 [02:21<05:43, 20.23s/it]

 33%|███▎      | 8/24 [02:41<05:22, 20.16s/it]

 38%|███▊      | 9/24 [03:01<05:02, 20.18s/it]

 42%|████▏     | 10/24 [03:21<04:43, 20.22s/it]

 46%|████▌     | 11/24 [03:42<04:23, 20.24s/it]

 50%|█████     | 12/24 [04:02<04:02, 20.25s/it]

 54%|█████▍    | 13/24 [04:22<03:42, 20.19s/it]

 58%|█████▊    | 14/24 [04:42<03:22, 20.30s/it]

 62%|██████▎   | 15/24 [05:03<03:02, 20.32s/it]

 67%|██████▋   | 16/24 [05:23<02:42, 20.31s/it]

 71%|███████   | 17/24 [05:44<02:23, 20.53s/it]

 75%|███████▌  | 18/24 [06:05<02:02, 20.49s/it]

 79%|███████▉  | 19/24 [06:25<01:41, 20.37s/it]

 83%|████████▎ | 20/24 [06:45<01:21, 20.38s/it]

 88%|████████▊ | 21/24 [07:05<01:01, 20.36s/it]

 92%|█████████▏| 22/24 [07:26<00:40, 20.39s/it]

 96%|█████████▌| 23/24 [07:46<00:20, 20.28s/it]

100%|██████████| 24/24 [08:06<00:00, 20.29s/it]

100%|██████████| 24/24 [08:06<00:00, 20.28s/it]


Single Solvent MSE: 0.025891 (+/- 0.025647)





In [10]:
# Cross-validation for full data (MIXTURE DATA - THE KEY TEST)
print("="*60)
print("Cross-validation: Full Data (Leave-One-Ramp-Out)")
print("THIS IS THE KEY TEST - Mixture handling should now work!")
print("="*60)

X_full, Y_full = load_data("full")
print(f"Full data: X={X_full.shape}, Y={Y_full.shape}")

# generate_leave_one_ramp_out_splits yields one fold per distinct
# (Solvent A, Solvent B) pair, so take the progress-bar total from the data
# instead of hard-coding it.
n_folds_full = len(X_full[["SOLVENT A NAME", "SOLVENT B NAME"]].drop_duplicates())

all_mse_full = []
for (train_X, train_Y), (test_X, test_Y) in tqdm.tqdm(
    generate_leave_one_ramp_out_splits(X_full, Y_full), total=n_folds_full
):
    # A fresh model per fold: nothing learned on one ramp leaks into the next.
    model = DualGNNWrapper(data='full', epochs=100, batch_size=32, lr=0.001)
    model.train_model(train_X, train_Y)
    preds = model.predict(test_X).numpy()
    # Fold score: MSE averaged over all held-out rows and all three targets.
    mse = np.mean((preds - test_Y.values) ** 2)
    all_mse_full.append(mse)

# Unweighted mean over folds (every ramp counts equally).
mse_full = np.mean(all_mse_full)
print(f"\nFull Data MSE: {mse_full:.6f} (+/- {np.std(all_mse_full):.6f})")

Cross-validation: Full Data (Leave-One-Ramp-Out)
THIS IS THE KEY TEST - Mixture handling should now work!
Full data: X=(1227, 5), Y=(1227, 3)


  0%|          | 0/13 [00:00<?, ?it/s]

  8%|▊         | 1/13 [00:36<07:14, 36.17s/it]

 15%|█▌        | 2/13 [01:12<06:40, 36.40s/it]

 23%|██▎       | 3/13 [01:49<06:06, 36.66s/it]

 31%|███       | 4/13 [02:25<05:28, 36.49s/it]

 38%|███▊      | 5/13 [03:02<04:51, 36.42s/it]

 46%|████▌     | 6/13 [03:37<04:13, 36.20s/it]

 54%|█████▍    | 7/13 [04:14<03:36, 36.14s/it]

 62%|██████▏   | 8/13 [04:50<03:00, 36.14s/it]

 69%|██████▉   | 9/13 [05:26<02:24, 36.11s/it]

 77%|███████▋  | 10/13 [06:05<01:50, 36.99s/it]

 85%|████████▍ | 11/13 [06:44<01:15, 37.67s/it]

 92%|█████████▏| 12/13 [07:23<00:38, 38.14s/it]

100%|██████████| 13/13 [08:02<00:00, 38.40s/it]

100%|██████████| 13/13 [08:02<00:00, 37.12s/it]


Full Data MSE: 0.023685 (+/- 0.018255)





In [11]:
# Calculate overall MSE: the two task scores are combined weighted by the
# number of rows in each table.
N_single = len(X_single)
N_full = len(X_full)
N_total = N_single + N_full

overall_mse = (mse_single * N_single + mse_full * N_full) / N_total

# Reference numbers used in the comparison, named once so that the printed
# value and the value used in the arithmetic cannot drift apart.
PREV_GNN_MSE = 0.026222       # previous GNN (broken mixture handling)
BEST_TABULAR_MSE = 0.008298   # best tabular ensemble (GP+MLP+LGBM)
CV_LB_SLOPE = 4.31            # CV -> LB line: LB = slope * CV + intercept
CV_LB_INTERCEPT = 0.0525
TARGET_LB = 0.0347

print("="*60)
print("SUMMARY")
print("="*60)
print(f"\nDual-Encoder GNN (Fixed Mixture Handling):")
print(f"  Single Solvent MSE: {mse_single:.6f}")
print(f"  Full Data MSE: {mse_full:.6f}")
print(f"  Overall MSE: {overall_mse:.6f}")

print(f"\nComparison:")
print(f"  Previous GNN (broken mixture): {PREV_GNN_MSE:.6f}")
print(f"  Best tabular (GP+MLP+LGBM): {BEST_TABULAR_MSE:.6f}")
print(f"  This GNN vs Previous: {(overall_mse - PREV_GNN_MSE) / PREV_GNN_MSE * 100:.2f}%")
print(f"  This GNN vs Best tabular: {(overall_mse - BEST_TABULAR_MSE) / BEST_TABULAR_MSE * 100:.2f}%")

# Expected LB based on CV-LB relationship.
# NOTE: the intercept of this line (0.0525) is already above the target LB
# (0.0347), so under this fit no CV improvement alone reaches the target.
expected_lb = CV_LB_SLOPE * overall_mse + CV_LB_INTERCEPT
print(f"\nExpected LB (based on CV-LB line): {expected_lb:.4f}")
print(f"Target LB: {TARGET_LB:.4f}")

SUMMARY

Dual-Encoder GNN (Fixed Mixture Handling):
  Single Solvent MSE: 0.025891
  Full Data MSE: 0.023685
  Overall MSE: 0.024454

Comparison:
  Previous GNN (broken mixture): 0.026222
  Best tabular (GP+MLP+LGBM): 0.008298
  This GNN vs Previous: -6.74%
  This GNN vs Best tabular: 194.69%

Expected LB (based on CV-LB line): 0.1579
Target LB: 0.0347


In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

import tqdm

X, Y = load_data("single_solvent")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = DualGNNWrapper(data='single')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_single_solvent = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE THIRD LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

X, Y = load_data("full")

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator)):
    (train_X, train_Y), (test_X, test_Y) = split

    model = DualGNNWrapper(data='full')  # CHANGE THIS LINE ONLY
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)  # Shape: [N, 3]

    # Move to CPU and convert to numpy
    predictions_np = predictions.detach().cpu().numpy()

    # Add metadata and flatten to long format
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

# Save final submission
submission_full_data = pd.DataFrame(all_predictions)

########### DO NOT CHANGE ANYTHING IN THIS CELL OTHER THAN THE MODEL #################
########### THIS MUST BE THE SECOND LAST CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

In [None]:
########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################

submission = pd.concat([submission_single_solvent, submission_full_data])
submission = submission.reset_index()
submission.index.name = "id"
submission.to_csv("/home/submission/submission.csv", index=True)

########### DO NOT CHANGE ANYTHING IN THIS CELL #################
########### THIS MUST BE THE FINAL CELL IN YOUR NOTEBOOK FOR A VALID SUBMISSION #################