# Experiment 088: ChemBERTa Pre-trained Molecular Embeddings

**Rationale**: The GNN benchmark achieved 0.0039 CV using pre-trained representations. Our GNN experiments without pre-training all failed (CV 0.018-0.020). The key missing ingredient is pre-training on large molecular datasets.

**Approach**:
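1. Use ChemBERTa (described as pre-trained on 77M molecules from PubChem; note that the checkpoint actually loaded below is `seyonec/ChemBERTa-zinc-base-v1`, whose name indicates ZINC pre-training, so the corpus should be confirmed against the model card)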
2. Extract embeddings for each solvent SMILES (768-dim)
3. Use embeddings as features instead of Spange descriptors
4. Train simple MLP on ChemBERTa embeddings + T + RT

**Key hypothesis**: Pre-trained embeddings capture molecular knowledge that generalizes to unseen solvents, potentially reducing the CV-LB intercept.

In [None]:
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import warnings
warnings.filterwarnings('ignore')

SEED = 42

# Seed NumPy and PyTorch (CPU and CUDA) so repeated runs are comparable.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)

# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if _cuda_ok else torch.device('cpu')
print(f'Device: {device}')
if _cuda_ok:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {_gpu_props.total_memory / 1e9:.1f} GB')

In [None]:
# Load the pre-trained ChemBERTa tokenizer and encoder (inference only)
from transformers import AutoModel, AutoTokenizer

model_name = "seyonec/ChemBERTa-zinc-base-v1"
print("Loading ChemBERTa model...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move the encoder to the compute device and switch off dropout.
chemberta = AutoModel.from_pretrained(model_name).to(device)
chemberta.eval()

print(f"ChemBERTa loaded: {model_name}")
print(f"Embedding dimension: {chemberta.config.hidden_size}")

In [None]:
# Load data
def load_data(data_type, data_dir='/home/data'):
    """Load the feature and target frames for one of the two yield datasets.

    Args:
        data_type: ``"single_solvent"`` for the single-solvent table or
            ``"full"`` for the two-solvent mixture table.
        data_dir: Directory that contains the CSV files.  Defaults to the
            location the notebook has always read from.

    Returns:
        A ``(X, Y)`` tuple of DataFrames: ``X`` holds the process conditions
        and solvent name column(s), ``Y`` holds the three target columns
        ``SM``, ``Product 2`` and ``Product 3``.

    Raises:
        ValueError: If ``data_type`` is not one of the two supported values
            (previously this surfaced as an ``UnboundLocalError``).
    """
    from pathlib import Path  # local import keeps this cell self-contained

    target_cols = ['SM', 'Product 2', 'Product 3']
    sources = {
        "single_solvent": (
            'catechol_single_solvent_yields.csv',
            ['Residence Time', 'Temperature', 'SOLVENT NAME'],
        ),
        "full": (
            'catechol_full_data_yields.csv',
            ['Residence Time', 'Temperature', 'SOLVENT A NAME', 'SOLVENT B NAME', 'SolventB%'],
        ),
    }

    try:
        filename, feature_cols = sources[data_type]
    except KeyError:
        raise ValueError(
            f"Unknown data_type {data_type!r}; expected one of {sorted(sources)}"
        ) from None

    df = pd.read_csv(Path(data_dir) / filename)
    return df[feature_cols], df[target_cols]

# Build the solvent-name -> SMILES mapping from the lookup table
smiles_df = pd.read_csv('/home/data/smiles_lookup.csv')
smiles_dict = {
    solvent: smiles_string
    for solvent, smiles_string in zip(smiles_df['SOLVENT NAME'], smiles_df['solvent smiles'])
}
print(f"Loaded {len(smiles_dict)} SMILES")

In [None]:
# Extract ChemBERTa embeddings for all solvents
@torch.no_grad()
def get_chemberta_embedding(smiles):
    """Return the first-token ([CLS]) ChemBERTa vector for a SMILES string.

    The string is tokenized (truncated to 512 tokens), run through the
    module-level ``chemberta`` encoder on ``device`` with gradients disabled,
    and the hidden state at position 0 is returned as a flat NumPy array.
    """
    encoded = tokenizer(smiles, return_tensors="pt", padding=True, truncation=True, max_length=512)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    hidden_states = chemberta(**encoded).last_hidden_state
    # Position 0 along the sequence axis is the summary token.
    first_token = hidden_states[:, 0, :]
    return first_token.cpu().numpy().flatten()

# Cache one embedding per solvent so the CV loops never re-run the encoder
print("Computing ChemBERTa embeddings for all solvents...")
chemberta_embeddings = {}
for solvent, smi in smiles_dict.items():
    vector = get_chemberta_embedding(smi)
    chemberta_embeddings[solvent] = vector
    print(f"{solvent}: {smi[:30]}... -> embedding shape {vector.shape}")

n_embedded = len(chemberta_embeddings)
print(f"\nTotal: {n_embedded} solvent embeddings")
first_vector = list(chemberta_embeddings.values())[0]
print(f"Embedding dimension: {first_vector.shape[0]}")

In [None]:
# Official CV split functions (DO NOT MODIFY)
from typing import Any, Generator

def generate_leave_one_out_splits(
    X: pd.DataFrame, Y: pd.DataFrame
) -> Generator[
    tuple[tuple[pd.DataFrame, pd.DataFrame], tuple[pd.DataFrame, pd.DataFrame]],
    Any,
    None,
]:
    """Yield leave-one-solvent-out folds.

    One fold is produced per distinct value of ``X["SOLVENT NAME"]``.  Each
    fold is ``((train_X, train_Y), (test_X, test_Y))`` where the test part
    holds every row of that solvent and the train part holds all other rows.

    The same boolean mask is applied to both ``X`` and ``Y``, so this assumes
    the two frames share an index -- TODO confirm against callers.
    """
    for solvent in X["SOLVENT NAME"].unique():
        # Complementary masks: the held-out solvent vs. everything else.
        train_mask = X["SOLVENT NAME"] != solvent
        test_mask = X["SOLVENT NAME"] == solvent
        yield (
            (X[train_mask], Y[train_mask]),
            (X[test_mask], Y[test_mask]),
        )

def generate_leave_one_ramp_out_splits(
    X: pd.DataFrame, Y: pd.DataFrame
) -> Generator[
    tuple[tuple[pd.DataFrame, pd.DataFrame], tuple[pd.DataFrame, pd.DataFrame]],
    Any,
    None,
]:
    """Yield leave-one-ramp-out folds for the mixture data.

    A "ramp" is identified by the string ``"<SOLVENT A NAME>_<SOLVENT B NAME>"``.
    One fold is produced per distinct ramp; each fold is
    ``((train_X, train_Y), (test_X, test_Y))`` with the held-out ramp's rows
    in the test part and all remaining rows in the train part.

    The same boolean mask is applied to both ``X`` and ``Y``, so this assumes
    the two frames share an index -- TODO confirm against callers.
    """
    # The pair label is order-sensitive: (A, B) and (B, A) are different ramps.
    ramps = X["SOLVENT A NAME"].astype(str) + "_" + X["SOLVENT B NAME"].astype(str)
    for ramp in ramps.unique():
        train_mask = ramps != ramp
        test_mask = ramps == ramp
        yield (
            (X[train_mask], Y[train_mask]),
            (X[test_mask], Y[test_mask]),
        )

print("CV split functions defined")

In [None]:
# MLP Model using ChemBERTa embeddings
class ChemBERTaMLPModel(nn.Module):
    """MLP mapping a solvent embedding plus two scalars to ``out_dim`` outputs in (0, 1).

    The network is three Linear -> ReLU -> Dropout(0.2) stages of widths
    ``hidden_dim``, ``hidden_dim`` and ``hidden_dim // 2`` followed by a
    Linear -> Sigmoid head.
    """

    def __init__(self, emb_dim=768, hidden_dim=256, out_dim=3):
        super().__init__()
        # Input features: the embedding followed by T and RT.
        widths = [emb_dim + 2, hidden_dim, hidden_dim, hidden_dim // 2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        layers += [nn.Linear(widths[-1], out_dim), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, emb, T, RT):
        """Concatenate ``emb`` with ``T`` and ``RT`` as extra columns and run the MLP."""
        features = torch.cat((emb, T[:, None], RT[:, None]), dim=1)
        return self.net(features)


print("ChemBERTaMLPModel defined")

In [None]:
# MLP Model for MIXTURES using ChemBERTa embeddings
class ChemBERTaMixtureMLPModel(nn.Module):
    """MLP for two-solvent mixtures, producing ``out_dim`` outputs in (0, 1).

    The input is both solvent embeddings followed by three scalars
    (mix fraction, T, RT).  The body is three Linear -> ReLU -> Dropout(0.2)
    stages of widths ``hidden_dim``, ``hidden_dim`` and ``hidden_dim // 2``
    followed by a Linear -> Sigmoid head.
    """

    def __init__(self, emb_dim=768, hidden_dim=256, out_dim=3):
        super().__init__()
        # Input features: embedding A, embedding B, then mix_frac, T and RT.
        widths = [emb_dim * 2 + 3, hidden_dim, hidden_dim, hidden_dim // 2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.2)]
        layers += [nn.Linear(widths[-1], out_dim), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, emb_A, emb_B, mix_frac, T, RT):
        """Concatenate both embeddings with the three scalar columns and run the MLP."""
        scalar_columns = (mix_frac[:, None], T[:, None], RT[:, None])
        features = torch.cat((emb_A, emb_B, *scalar_columns), dim=1)
        return self.net(features)


print("ChemBERTaMixtureMLPModel defined")

In [None]:
# Model wrapper
class ChemBERTaModel:
    """Train/predict wrapper around the ChemBERTa-embedding MLPs.

    Passing ``data='full'`` selects the two-solvent mixture network
    (``ChemBERTaMixtureMLPModel``); any other value selects the
    single-solvent network (``ChemBERTaMLPModel``).  Solvent embeddings are
    looked up in the module-level ``chemberta_embeddings`` dict and all
    tensors are moved to the module-level ``device``.

    Args:
        data: Task selector, see above.
        hidden_dim: Width of the MLP hidden layers.
        num_epochs: Number of passes over the training rows; also used as the
            cosine-annealing period.
        lr: Initial AdamW learning rate.
    """

    # Mini-batch size shared by both training paths.
    _BATCH_SIZE = 32

    def __init__(self, data='single', hidden_dim=256, num_epochs=300, lr=1e-3):
        self.data_type = data
        self.mixed = (data == 'full')
        self.hidden_dim = hidden_dim
        self.num_epochs = num_epochs
        self.lr = lr
        self.model = None
        # Default embedding width; replaced by the width of the embeddings
        # actually seen when a model is trained.
        self.emb_dim = 768

    def train_model(self, train_X, train_Y):
        """Fit a fresh network on ``train_X`` / ``train_Y`` for the configured task.

        Raises:
            ValueError: If no training row refers to a solvent that has an
                entry in ``chemberta_embeddings``.
        """
        if self.mixed:
            self._train_mixed(train_X, train_Y)
        else:
            self._train_single(train_X, train_Y)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _to_tensor(values):
        """Copy array-like ``values`` into a ``torch.float`` tensor on ``device``."""
        return torch.tensor(np.asarray(values), dtype=torch.float).to(device)

    @staticmethod
    def _embedding_or_fallback(name):
        """Return the embedding for ``name``, or the first stored one if unknown.

        NOTE(review): the fallback is silent and arbitrary (whatever solvent
        happens to be first in ``chemberta_embeddings``); it is kept so that
        prediction never drops rows, but such rows deserve a second look.
        """
        if name in chemberta_embeddings:
            return chemberta_embeddings[name]
        return next(iter(chemberta_embeddings.values()))

    @staticmethod
    def _require_rows(keep):
        """Fail loudly when every training row was filtered out.

        Without this check the scalers become NaN and an untrained network
        is returned without any error.
        """
        if not keep.any():
            raise ValueError(
                "No training rows refer to a solvent with a ChemBERTa embedding"
            )

    def _fit_scaler(self, temps, rts):
        """Store mean/std of the training T and RT and return both standardized."""
        # The small epsilon guards against a zero standard deviation.
        self.temp_mean, self.temp_std = temps.mean(), temps.std() + 1e-8
        self.rt_mean, self.rt_std = rts.mean(), rts.std() + 1e-8
        return self._scale(temps, rts)

    def _scale(self, temps, rts):
        """Standardize T and RT with the statistics stored by ``_fit_scaler``."""
        scaled_temps = (temps - self.temp_mean) / self.temp_std
        scaled_rts = (rts - self.rt_mean) / self.rt_std
        return scaled_temps, scaled_rts

    def _fit(self, inputs, targets):
        """Run the mini-batch training loop on ``self.model``.

        Args:
            inputs: Tuple of tensors, passed positionally to the model after
                being indexed with the same batch indices.
            targets: Target tensor aligned row-for-row with ``inputs``.

        Uses AdamW (weight decay 1e-4), MSE loss, gradient-norm clipping at
        1.0 and one cosine-annealing step per epoch.  The model is left in
        eval mode.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs, eta_min=1e-6)

        self.model.train()
        n_samples = len(targets)

        for _ in range(self.num_epochs):
            # Fresh shuffle every epoch (driven by NumPy's global RNG).
            order = np.random.permutation(n_samples)
            for start in range(0, n_samples, self._BATCH_SIZE):
                batch_idx = order[start:start + self._BATCH_SIZE]

                optimizer.zero_grad()
                outputs = self.model(*(tensor[batch_idx] for tensor in inputs))
                loss = F.mse_loss(outputs, targets[batch_idx])
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()

            scheduler.step()

        self.model.eval()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def _train_single(self, train_X, train_Y):
        """Fit the single-solvent network.

        Rows whose ``SOLVENT NAME`` has no stored embedding are dropped.
        """
        names = train_X['SOLVENT NAME'].tolist()
        keep = np.array([name in chemberta_embeddings for name in names], dtype=bool)
        self._require_rows(keep)

        embeddings = np.array(
            [chemberta_embeddings[name] for name, ok in zip(names, keep) if ok]
        )
        temps, rts = self._fit_scaler(
            train_X['Temperature'].to_numpy(dtype=float)[keep],
            train_X['Residence Time'].to_numpy(dtype=float)[keep],
        )
        targets = train_Y.to_numpy()[keep]

        # Size the network from the data instead of assuming a fixed width.
        self.emb_dim = embeddings.shape[1]
        inputs = (
            self._to_tensor(embeddings),
            self._to_tensor(temps),
            self._to_tensor(rts),
        )
        targets = self._to_tensor(targets)

        self.model = ChemBERTaMLPModel(
            emb_dim=self.emb_dim, hidden_dim=self.hidden_dim, out_dim=3
        ).to(device)
        self._fit(inputs, targets)

    def _train_mixed(self, train_X, train_Y):
        """Fit the mixture network.

        Rows are dropped unless both ``SOLVENT A NAME`` and ``SOLVENT B NAME``
        have a stored embedding.  ``SolventB%`` is divided by 100 before use
        (presumably a percentage -- verify against the data).
        """
        names_a = train_X['SOLVENT A NAME'].tolist()
        names_b = train_X['SOLVENT B NAME'].tolist()
        keep = np.array(
            [a in chemberta_embeddings and b in chemberta_embeddings
             for a, b in zip(names_a, names_b)],
            dtype=bool,
        )
        self._require_rows(keep)

        emb_a = np.array([chemberta_embeddings[a] for a, ok in zip(names_a, keep) if ok])
        emb_b = np.array([chemberta_embeddings[b] for b, ok in zip(names_b, keep) if ok])
        mix_fracs = train_X['SolventB%'].to_numpy(dtype=float)[keep] / 100.0
        temps, rts = self._fit_scaler(
            train_X['Temperature'].to_numpy(dtype=float)[keep],
            train_X['Residence Time'].to_numpy(dtype=float)[keep],
        )
        targets = train_Y.to_numpy()[keep]

        # Size the network from the data instead of assuming a fixed width.
        self.emb_dim = emb_a.shape[1]
        inputs = (
            self._to_tensor(emb_a),
            self._to_tensor(emb_b),
            self._to_tensor(mix_fracs),
            self._to_tensor(temps),
            self._to_tensor(rts),
        )
        targets = self._to_tensor(targets)

        self.model = ChemBERTaMixtureMLPModel(
            emb_dim=self.emb_dim, hidden_dim=self.hidden_dim, out_dim=3
        ).to(device)
        self._fit(inputs, targets)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def predict(self, test_X):
        """Return the network outputs for every row of ``test_X``.

        Returns:
            A tensor on ``device`` with one row per input row and one column
            per target; computed with gradients disabled.

        Raises:
            RuntimeError: If called before ``train_model``.
        """
        if self.model is None:
            raise RuntimeError("ChemBERTaModel.predict() called before train_model()")
        self.model.eval()
        with torch.no_grad():
            if self.mixed:
                return self._predict_mixed(test_X)
            else:
                return self._predict_single(test_X)

    def _predict_single(self, test_X):
        """Predict for single-solvent rows; unknown solvents use the fallback embedding."""
        embeddings = np.array(
            [self._embedding_or_fallback(name) for name in test_X['SOLVENT NAME'].tolist()]
        )
        temps, rts = self._scale(
            test_X['Temperature'].to_numpy(dtype=float),
            test_X['Residence Time'].to_numpy(dtype=float),
        )
        return self.model(
            self._to_tensor(embeddings),
            self._to_tensor(temps),
            self._to_tensor(rts),
        )

    def _predict_mixed(self, test_X):
        """Predict for mixture rows; unknown solvents use the fallback embedding."""
        emb_a = np.array(
            [self._embedding_or_fallback(name) for name in test_X['SOLVENT A NAME'].tolist()]
        )
        emb_b = np.array(
            [self._embedding_or_fallback(name) for name in test_X['SOLVENT B NAME'].tolist()]
        )
        mix_fracs = test_X['SolventB%'].to_numpy(dtype=float) / 100.0
        temps, rts = self._scale(
            test_X['Temperature'].to_numpy(dtype=float),
            test_X['Residence Time'].to_numpy(dtype=float),
        )
        return self.model(
            self._to_tensor(emb_a),
            self._to_tensor(emb_b),
            self._to_tensor(mix_fracs),
            self._to_tensor(temps),
            self._to_tensor(rts),
        )


print("ChemBERTaModel wrapper defined")

In [None]:
# Run leave-one-solvent-out CV for the single solvent data
import tqdm

X, Y = load_data("single_solvent")
# One fold per distinct solvent, counted the same way the split generator
# enumerates them, so the progress bar stays correct if the dataset changes.
n_folds = len(X['SOLVENT NAME'].unique())
print(f"Single solvent data: {len(X)} samples, {n_folds} solvents")

split_generator = generate_leave_one_out_splits(X, Y)
all_predictions = []
fold_mses = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator), total=n_folds):
    (train_X, train_Y), (test_X, test_Y) = split

    # A fresh model per fold: nothing learned on one split leaks into another.
    model = ChemBERTaModel(data='single', hidden_dim=256, num_epochs=300, lr=1e-3)
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)
    predictions_np = predictions.detach().cpu().numpy()

    # MSE over every row and target of the held-out solvent.
    fold_mse = np.mean((predictions_np - test_Y.values) ** 2)
    fold_mses.append(fold_mse)

    # "row" is the position inside the fold, not the original DataFrame index.
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 0,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

submission_single_solvent = pd.DataFrame(all_predictions)
print(f"\nSingle solvent predictions: {len(submission_single_solvent)} rows")
print(f"Mean fold MSE: {np.mean(fold_mses):.6f} ± {np.std(fold_mses):.6f}")

In [None]:
# Run leave-one-ramp-out CV for the full (mixture) data
X, Y = load_data("full")
print(f"Full data: {len(X)} samples")

# One fold per distinct (solvent A, solvent B) pair; the label is built the
# same way the split generator builds it, so the progress-bar total always
# matches the number of folds actually produced.
ramp_labels = X["SOLVENT A NAME"].astype(str) + "_" + X["SOLVENT B NAME"].astype(str)
n_folds = len(ramp_labels.unique())

split_generator = generate_leave_one_ramp_out_splits(X, Y)
all_predictions = []
fold_mses = []

for fold_idx, split in tqdm.tqdm(enumerate(split_generator), total=n_folds):
    (train_X, train_Y), (test_X, test_Y) = split

    # A fresh model per fold: nothing learned on one split leaks into another.
    model = ChemBERTaModel(data='full', hidden_dim=256, num_epochs=300, lr=1e-3)
    model.train_model(train_X, train_Y)

    predictions = model.predict(test_X)
    predictions_np = predictions.detach().cpu().numpy()

    # MSE over every row and target of the held-out ramp.
    fold_mse = np.mean((predictions_np - test_Y.values) ** 2)
    fold_mses.append(fold_mse)

    # "row" is the position inside the fold, not the original DataFrame index.
    for row_idx, row in enumerate(predictions_np):
        all_predictions.append({
            "task": 1,
            "fold": fold_idx,
            "row": row_idx,
            "target_1": row[0],
            "target_2": row[1],
            "target_3": row[2]
        })

submission_full_data = pd.DataFrame(all_predictions)
print(f"\nFull data predictions: {len(submission_full_data)} rows")
print(f"Mean fold MSE: {np.mean(fold_mses):.6f} ± {np.std(fold_mses):.6f}")

In [None]:
# Stack both tasks into one table, write it out, and sanity-check the file
submission = pd.concat([submission_single_solvent, submission_full_data]).reset_index()
# The fresh 0..n-1 index becomes the "id" column; the pre-concat positions
# stay in the table as the "index" column produced by reset_index().
submission.index.name = "id"

print(f"Submission shape: {submission.shape}")

submission_path = "/home/submission/submission.csv"
submission.to_csv(submission_path, index=True)
print("\nSubmission saved to /home/submission/submission.csv")

# Read the file back so the report reflects what was actually written.
submission_check = pd.read_csv(submission_path)
print(f"\nSubmission rows: {len(submission_check)}")

target_cols = ['target_1', 'target_2', 'target_3']
for col in target_cols:
    lowest = submission_check[col].min()
    highest = submission_check[col].max()
    print(f"{col}: min={lowest:.4f}, max={highest:.4f}")

In [None]:
# Print a closing banner and a recap of the experiment setup
# (no score is computed here; the per-task fold MSEs are printed by the CV cells)
banner = "=" * 50
print(banner)
print("EXPERIMENT 088 COMPLETE")
print(banner)

print("\nKey techniques:")
for technique in (
    "1. ChemBERTa pre-trained embeddings (768-dim)",
    "2. Pre-trained on 77M molecules from PubChem",
    "3. Simple MLP on embeddings + T + RT",
    "4. 300 epochs with cosine annealing LR",
):
    print(technique)
print("\nThis uses pre-trained molecular knowledge that should generalize to unseen solvents.")