# Experiment 050: LISA Selective Augmentation + REx Loss

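**Hypothesis**: Domain-generalization techniques (LISA + REx) can learn domain-invariant predictors that generalize better to unseen solvents (no data from the held-out solvent is used during training, so this is generalization rather than adaptation).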

**Based on**:
- Yao et al., "Improving Out-of-Distribution Robustness via Selective Augmentation" (LISA), ICML 2022
- Krueger et al., "Out-of-Distribution Generalization via Risk Extrapolation" (REx), ICML 2021

**Implementation**:
1. LISA: Interpolate pairs of samples that have similar labels but different solvents (the intra-label form of LISA)
2. REx: Penalize the variance of the per-solvent training losses (the V-REx variant)
3. Use the exp_030 ensemble architecture (GP + MLP + LGBM)

**Success criteria**:
- Not just better CV, but better worst-case performance
- Lower variance of per-solvent errors
- If CV is similar but per-solvent variance is lower, this could change the CV-LB relationship

In [1]:
import sys
sys.path.insert(0, '/home/code/experiments/049_manual_ood_handling')

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel, Matern
import lightgbm as lgb
import warnings
warnings.filterwarnings('ignore')

# Load data
from utils_local import load_data, load_features, generate_leave_one_out_splits, generate_leave_one_ramp_out_splits

print("Loading data...")
X_single_raw, Y_single = load_data("single_solvent")
X_full_raw, Y_full = load_data("full")
print(
    f"Single solvent: {X_single_raw.shape}, "
    f"Mixtures: {X_full_raw.shape}"
)

# Per-solvent descriptor tables; the dataset builders below look their rows
# up by solvent name with .loc.
spange = load_features("spange_descriptors")
drfp = load_features("drfps_catechol")
print(f"Spange: {spange.shape}, DRFP: {drfp.shape}")

Loading data...
Single solvent: (656, 3), Mixtures: (1227, 5)
Spange: (26, 13), DRFP: (24, 2048)


In [2]:
# Prepare datasets with features
def prepare_single_solvent_dataset(X_raw, spange, drfp):
    """Build the single-solvent feature table.

    Every row of ``X_raw`` is expanded with the Spange descriptor row and the
    DRFP row of its solvent (looked up by 'SOLVENT NAME'). These are followed
    by the reaction conditions and the solvent name itself.

    Returns:
        A new DataFrame with a fresh RangeIndex. Its columns are the Spange
        columns, 'DRFP_0'..'DRFP_{k-1}', 'TEMPERATURE', 'TIME' and
        'SOLVENT NAME'.
    """
    names = X_raw['SOLVENT NAME'].values

    descriptor_block = pd.DataFrame(
        spange.loc[names].values,
        columns=spange.columns.tolist(),
    )
    fingerprint_block = pd.DataFrame(
        drfp.loc[names].values,
        columns=[f'DRFP_{k}' for k in range(drfp.shape[1])],
    )

    residence = X_raw['Residence Time'].values
    temperature = X_raw['Temperature'].values

    out = pd.concat([descriptor_block, fingerprint_block], axis=1)
    # Assigned from raw arrays, so values follow row order, not X_raw's index.
    out['TEMPERATURE'] = temperature
    out['TIME'] = residence
    out['SOLVENT NAME'] = names
    return out

def prepare_mixture_dataset(X_raw, spange, drfp):
    """Build the binary-mixture feature table.

    Solvent descriptors are blended linearly by composition. Each Spange and
    DRFP row becomes ``(1 - b) * row(A) + b * row(B)`` with
    ``b = SolventB% / 100``.

    The blended features are followed by:
    - the reaction conditions,
    - a combined "A.B" solvent name,
    - the two component names,
    - the raw 'SolventB%' value.

    The returned DataFrame has a fresh RangeIndex.
    """
    comp_a = X_raw['SOLVENT A NAME'].values
    comp_b = X_raw['SOLVENT B NAME'].values
    frac_b = (X_raw['SolventB%'].values / 100.0)[:, None]

    def blend(table):
        # Row-wise convex combination of the two components' descriptor rows.
        return (1 - frac_b) * table.loc[comp_a].values + frac_b * table.loc[comp_b].values

    mixed_spange = blend(spange)
    mixed_drfp = blend(drfp)

    residence = X_raw['Residence Time'].values
    temperature = X_raw['Temperature'].values

    out = pd.concat(
        [
            pd.DataFrame(mixed_spange, columns=spange.columns.tolist()),
            pd.DataFrame(mixed_drfp, columns=[f'DRFP_{k}' for k in range(drfp.shape[1])]),
        ],
        axis=1,
    )
    out['TEMPERATURE'] = temperature
    out['TIME'] = residence
    out['SOLVENT NAME'] = [f"{a}.{b}" for a, b in zip(comp_a, comp_b)]
    out['SOLVENT A NAME'] = comp_a
    out['SOLVENT B NAME'] = comp_b
    out['SolventB%'] = X_raw['SolventB%'].values
    return out

# Model-ready feature frames for the two evaluation tasks.
X_single = prepare_single_solvent_dataset(X_single_raw, spange, drfp)
X_mix = prepare_mixture_dataset(X_full_raw, spange, drfp)

for label, frame in (("Single solvent dataset", X_single), ("Mixture dataset", X_mix)):
    print(f"{label}: {frame.shape}")

Single solvent dataset: (656, 2064)
Mixture dataset: (1227, 2067)


In [3]:
# Feature extraction functions
def get_spange_features(X_data):
    """Return the 13 Spange solvent-descriptor columns of ``X_data`` as a NumPy array.

    Columns are always emitted in the fixed order listed below, regardless of
    their order in ``X_data``. Any other columns are ignored.
    """
    descriptor_order = [
        'dielectric constant', 'ET(30)', 'alpha', 'beta', 'pi*',
        'SA', 'SB', 'SP', 'SdP', 'N', 'n', 'f(n)', 'delta',
    ]
    return X_data.loc[:, descriptor_order].to_numpy()

def get_arrhenius_features(X_data):
    """Return Arrhenius-style condition features as an (n, 5) array.

    Column order: 1/T_K, ln(t), (1/T_K) * ln(t), raw TEMPERATURE, raw TIME.
    T_K is TEMPERATURE + 273.15, so TEMPERATURE is presumably in degrees
    Celsius. A 1e-6 offset keeps ln(t) finite at t = 0.
    """
    temp_c = X_data['TEMPERATURE'].values
    res_time = X_data['TIME'].values
    inv_kelvin = 1.0 / (temp_c + 273.15)
    log_time = np.log(res_time + 1e-6)
    return np.column_stack([inv_kelvin, log_time, inv_kelvin * log_time, temp_c, res_time])

def prepare_features(X_data, drfp_mask=None, include_drfp=True):
    """Assemble the model input matrix from a prepared dataset frame.

    Layout: Spange descriptors, then (optionally) the 'DRFP_*' columns, then
    the Arrhenius condition features.

    Args:
        X_data: Frame produced by one of the dataset builders.
        drfp_mask: Optional column selector applied to the DRFP block
            (e.g. a boolean mask over the 'DRFP_*' columns); None keeps all.
        include_drfp: When False, the DRFP block is omitted entirely.

    Returns:
        A 2-D NumPy array with one row per row of ``X_data``.
    """
    spange_part = get_spange_features(X_data)
    arrhenius_part = get_arrhenius_features(X_data)

    if not include_drfp:
        return np.hstack([spange_part, arrhenius_part])

    fp_columns = [c for c in X_data.columns if c.startswith('DRFP_')]
    fingerprints = X_data[fp_columns].values
    if drfp_mask is not None:
        fingerprints = fingerprints[:, drfp_mask]
    return np.hstack([spange_part, fingerprints, arrhenius_part])


print("Feature extraction functions defined")

Feature extraction functions defined


In [4]:
# LISA Augmentation - Fixed to work with prepared features
def lisa_augmentation(X_scaled, Y_values, solvents, alpha=0.5, n_augment=None,
                      label_threshold=0.15, rng=None):
    """
    LISA (intra-label form): mix pairs of rows whose labels are close but
    whose solvents differ.

    For each of ``n_augment`` draws:
    - pick a random anchor row ``i``;
    - draw a partner ``j`` uniformly from the rows whose summed absolute label
      difference to ``i`` is below ``label_threshold`` and whose solvent
      differs from ``solvents[i]``;
    - mix features and labels with the same ``lam ~ Beta(alpha, alpha)``.

    Draws with no valid partner are skipped, so fewer than ``n_augment`` rows
    may be returned.

    Args:
        X_scaled: Already scaled feature matrix, shape (n, d).
        Y_values: Target matrix, shape (n, k). Label distance is summed over axis 1.
        solvents: Solvent name per row (any array-like of length n).
        alpha: Beta distribution parameter for mixup.
        n_augment: Number of draws to attempt; defaults to n.
        label_threshold: Maximum summed |Y_i - Y_j| for two rows to count as
            "similar" (0.15 reproduces the previous hard-coded value).
        rng: Optional ``np.random.RandomState`` for reproducible draws. When
            None, the global ``np.random`` state is used, as before.

    Returns:
        ``(aug_X, aug_Y, aug_solvents)`` arrays, with every ``aug_solvents``
        entry equal to 'MIXED'. Returns ``(None, None, None)`` if no sample
        could be created.

    Raises:
        ValueError: if ``solvents`` does not have one entry per row of ``X_scaled``.
    """
    X_scaled = np.asarray(X_scaled)
    Y_values = np.asarray(Y_values)
    # np.asarray makes `solvents != solvents[i]` elementwise even for a plain
    # list. On a list that comparison is a single scalar True, which would
    # silently allow same-solvent pairs.
    solvents = np.asarray(solvents)
    if len(solvents) != len(X_scaled):
        raise ValueError(
            f"solvents has {len(solvents)} entries but X_scaled has {len(X_scaled)} rows"
        )

    rand = np.random if rng is None else rng
    n_rows = len(X_scaled)
    if n_augment is None:
        n_augment = n_rows  # Augment same number as original

    augmented_X = []
    augmented_Y = []

    for _ in range(n_augment):
        i = rand.randint(n_rows)

        # Partners: similar labels AND a different solvent (domain).
        label_diff = np.abs(Y_values - Y_values[i]).sum(axis=1)
        candidates = np.flatnonzero((label_diff < label_threshold) & (solvents != solvents[i]))
        if candidates.size == 0:
            continue

        j = rand.choice(candidates)
        lam = rand.beta(alpha, alpha)
        augmented_X.append(lam * X_scaled[i] + (1 - lam) * X_scaled[j])
        augmented_Y.append(lam * Y_values[i] + (1 - lam) * Y_values[j])

    if not augmented_X:
        return None, None, None

    return (
        np.array(augmented_X),
        np.array(augmented_Y),
        # Synthetic rows form their own pseudo-domain for downstream REx grouping.
        np.array(['MIXED'] * len(augmented_X)),
    )


print("LISA augmentation function defined (fixed)")

LISA augmentation function defined (fixed)


In [5]:
# MLP Model with REx Loss
class MLPModelREx(nn.Module):
    """
    Feed-forward regressor with outputs squashed into [0, 1].

    Architecture: one ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stage per
    entry of ``hidden_dims``, then a ``Linear`` head followed by ``Sigmoid``.

    Args:
        input_dim: Number of input features.
        hidden_dims: Widths of the hidden layers, in order. The default is a
            tuple so it cannot be shared and mutated across calls.
        dropout: Dropout probability after each hidden stage.
        output_dim: Number of outputs; 3 keeps the original head size.
    """
    def __init__(self, input_dim, hidden_dims=(128, 64), dropout=0.2, output_dim=3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        layers.append(nn.Sigmoid())
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of feature vectors to ``output_dim`` values in [0, 1]."""
        return self.network(x)

def train_mlp_rex(X_train, Y_train, solvents, input_dim, epochs=200, lr=5e-4,
                  weight_decay=1e-4, hidden_dims=(128, 64), rex_beta=1.0):
    """
    Train an MLPModelREx with a V-REx objective.

    The objective is the mean of the per-solvent MSE losses plus ``rex_beta``
    times their variance. If only one solvent group is present, it falls back
    to plain MSE over all rows.

    Training is full-batch: each epoch is one forward/backward pass over every
    row. Adam is used with ``weight_decay``. ReduceLROnPlateau halves the
    learning rate after 20 epochs without improvement of the training
    objective.

    Args:
        X_train: Feature matrix (array-like), shape (n, input_dim).
        Y_train: Target matrix (array-like), one row per sample.
        solvents: Domain label per row (array-like of length n). Each distinct
            value forms one REx environment.
        input_dim: Number of input features.
        epochs: Number of full-batch optimisation steps.
        lr: Initial Adam learning rate.
        weight_decay: Adam weight decay.
        hidden_dims: Hidden-layer widths passed to MLPModelREx.
        rex_beta: Weight of the cross-domain loss-variance penalty.

    Returns:
        The trained model, on the selected device and still in train mode.

    Raises:
        ValueError: if ``solvents`` does not have one entry per training row.
    """
    # np.asarray keeps `solvents == s` elementwise for list input. With a plain
    # list it would be a scalar False, leaving every group empty and silently
    # disabling the REx penalty.
    solvents = np.asarray(solvents)
    if len(solvents) != len(X_train):
        raise ValueError(
            f"solvents has {len(solvents)} entries but X_train has {len(X_train)} rows"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MLPModelREx(input_dim, hidden_dims=list(hidden_dims)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20)

    X_tensor = torch.FloatTensor(np.asarray(X_train)).to(device)
    Y_tensor = torch.FloatTensor(np.asarray(Y_train)).to(device)

    # One index tensor per domain, built once on the training device instead
    # of re-converting NumPy index arrays every epoch.
    group_index = [
        torch.as_tensor(np.flatnonzero(solvents == s), dtype=torch.long, device=device)
        for s in np.unique(solvents)
    ]

    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        pred = model(X_tensor)

        if len(group_index) > 1:
            group_losses = torch.stack(
                [F.mse_loss(pred[idx], Y_tensor[idx]) for idx in group_index]
            )
            # REx loss = mean + beta * variance (torch.var's default unbiased estimator)
            loss = group_losses.mean() + rex_beta * group_losses.var()
        else:
            loss = F.mse_loss(pred, Y_tensor)

        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())

    return model


print("MLP with REx loss defined")

MLP with REx loss defined


In [6]:
# LISA + REx Model - Fixed
class LISARExModel:
    """
    Weighted ensemble of Gaussian processes, REx-trained MLPs and LightGBM
    regressors. It is fitted on standardized features augmented with LISA
    samples.

    What ``fit`` trains, for three targets:
    - one GP per target column, on a subsample of the original (un-augmented)
      rows;
    - three independently initialized multi-output MLPs, whose predictions
      are averaged;
    - one LightGBM regressor per target column, on original + augmented rows.

    ``predict`` returns
    ``gp_weight * GP + mlp_weight * MLP + lgbm_weight * LGBM``, clipped to
    [0, 1]. The weights are used as given and are not renormalized.
    """
    def __init__(self, gp_weight=0.15, mlp_weight=0.55, lgbm_weight=0.3,
                 lisa_alpha=0.5, rex_beta=1.0):
        self.gp_weight = gp_weight
        self.mlp_weight = mlp_weight
        self.lgbm_weight = lgbm_weight
        self.lisa_alpha = lisa_alpha
        self.rex_beta = rex_beta

        self.scaler = StandardScaler()
        self.gp_models = []
        self.mlp_models = []
        self.lgbm_models = []

        self.drfp_mask = None
        self.input_dim = None

    def fit(self, X_train, Y_train):
        """
        Train all ensemble members with LISA augmentation and REx loss.

        Args:
            X_train: Prepared dataset frame containing the columns read by
                prepare_features, plus 'SOLVENT NAME'.
            Y_train: DataFrame of targets. ``.values`` is indexed by columns 0..2,
                so at least 3 target columns are expected.

        Returns:
            self, so calls can be chained.
        """
        # Refitting must replace, not extend, previously trained members.
        # predict() writes GP/LightGBM outputs into column i of a 3-column
        # array, so leftovers from an earlier fit() would raise IndexError.
        self.gp_models = []
        self.mlp_models = []
        self.lgbm_models = []

        # Keep only DRFP bits that vary within the training rows.
        drfp_cols = [col for col in X_train.columns if col.startswith('DRFP_')]
        drfp_data = X_train[drfp_cols].values
        self.drfp_mask = drfp_data.var(axis=0) > 0

        # Prepare features
        X_features = prepare_features(X_train, self.drfp_mask, include_drfp=True)
        self.input_dim = X_features.shape[1]
        X_scaled = self.scaler.fit_transform(X_features)

        Y_values = Y_train.values
        solvents = X_train['SOLVENT NAME'].values

        # Apply LISA augmentation on scaled features
        aug_X, aug_Y, aug_solvents = lisa_augmentation(X_scaled, Y_values, solvents,
                                                        alpha=self.lisa_alpha)
        if aug_X is not None:
            # Combine original and augmented data
            X_combined = np.vstack([X_scaled, aug_X])
            Y_combined = np.vstack([Y_values, aug_Y])
            solvents_combined = np.concatenate([solvents, aug_solvents])
        else:
            X_combined = X_scaled
            Y_combined = Y_values
            solvents_combined = solvents

        # Train GP (on a random subsample of original data only, for speed)
        n_gp = min(200, len(X_scaled))
        idx_gp = np.random.choice(len(X_scaled), n_gp, replace=False)
        for i in range(3):
            kernel = Matern(nu=2.5) + WhiteKernel(noise_level=0.1)
            gp = GaussianProcessRegressor(kernel=kernel, alpha=0.1, n_restarts_optimizer=2)
            gp.fit(X_scaled[idx_gp], Y_values[idx_gp, i])
            self.gp_models.append(gp)

        # Train MLP with REx loss (on combined data)
        for _ in range(3):
            mlp = train_mlp_rex(X_combined, Y_combined, solvents_combined,
                                self.input_dim, epochs=200, hidden_dims=[128, 64],
                                rex_beta=self.rex_beta)
            self.mlp_models.append(mlp)

        # Train LightGBM (on combined data)
        lgbm_params = {
            'objective': 'regression',
            'metric': 'mse',
            'learning_rate': 0.03,
            'max_depth': 6,
            'num_leaves': 31,
            'reg_alpha': 0.1,
            'reg_lambda': 0.1,
            'verbose': -1,
            'n_estimators': 500
        }
        for i in range(3):
            model = lgb.LGBMRegressor(**lgbm_params)
            model.fit(X_combined, Y_combined[:, i])
            self.lgbm_models.append(model)

        return self

    def predict(self, X_test):
        """
        Predict the 3 targets for ``X_test`` with the weighted ensemble.

        Args:
            X_test: Prepared dataset frame with the same feature columns used
                in fit().

        Returns:
            Array of shape (len(X_test), 3), clipped to [0, 1].
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        X_features = prepare_features(X_test, self.drfp_mask, include_drfp=True)
        X_scaled = self.scaler.transform(X_features)

        # GP predictions (one model per target column)
        gp_preds = np.zeros((len(X_test), 3))
        for i, gp in enumerate(self.gp_models):
            gp_preds[:, i] = gp.predict(X_scaled)
        gp_preds = np.clip(gp_preds, 0, 1)

        # MLP predictions, averaged over the independently trained networks
        mlp_preds = []
        X_tensor = torch.FloatTensor(X_scaled).to(device)
        for mlp in self.mlp_models:
            mlp.eval()
            with torch.no_grad():
                pred = mlp(X_tensor).cpu().numpy()
            mlp_preds.append(pred)
        mlp_pred = np.mean(mlp_preds, axis=0)

        # LightGBM predictions (one model per target column)
        lgbm_pred = np.zeros((len(X_test), 3))
        for i, model in enumerate(self.lgbm_models):
            lgbm_pred[:, i] = model.predict(X_scaled)
        lgbm_pred = np.clip(lgbm_pred, 0, 1)

        # Ensemble
        final_pred = self.gp_weight * gp_preds + self.mlp_weight * mlp_pred + self.lgbm_weight * lgbm_pred

        return np.clip(final_pred, 0, 1)


print("LISARExModel defined (fixed)")

LISARExModel defined (fixed)


In [7]:
# Leave-one-solvent-out CV on the single-solvent data. Each fold holds out one
# solvent's rows, trains a fresh LISARExModel on the rest, and records the
# held-out MSE under that solvent's name.
print("Running Single Solvent CV with LISA + REx Model...")
print("=" * 60)

splits = list(generate_leave_one_out_splits(X_single, Y_single))
print(f"Number of folds: {len(splits)}")

solvent_errors_lisa = {}
all_preds_lisa, all_true_lisa = [], []

for fold_no, (train_idx, test_idx) in enumerate(splits, start=1):
    X_train, Y_train = X_single.iloc[train_idx], Y_single.iloc[train_idx]
    X_test, Y_test = X_single.iloc[test_idx], Y_single.iloc[test_idx]
    test_solvent = X_test['SOLVENT NAME'].iloc[0]

    # fit() returns the model itself, so training and prediction chain.
    model = LISARExModel(lisa_alpha=0.5, rex_beta=1.0).fit(X_train, Y_train)
    preds = model.predict(X_test)

    y_true = Y_test.values
    mse = np.mean((preds - y_true) ** 2)
    solvent_errors_lisa[test_solvent] = mse
    all_preds_lisa.append(preds)
    all_true_lisa.append(y_true)

    print(f"Fold {fold_no:2d}: {test_solvent:45s} MSE = {mse:.6f}")

all_preds_lisa = np.vstack(all_preds_lisa)
all_true_lisa = np.vstack(all_true_lisa)
single_mse_lisa = np.mean((all_preds_lisa - all_true_lisa) ** 2)
# Population std (ddof=0) of the per-solvent fold MSEs
single_std_lisa = np.std(list(solvent_errors_lisa.values()))

print(f"\nLISA + REx Single Solvent CV MSE: {single_mse_lisa:.6f} +/- {single_std_lisa:.6f}")

Running Single Solvent CV with LISA + REx Model...


Number of folds: 24


Fold  1: 1,1,1,3,3,3-Hexafluoropropan-2-ol             MSE = 0.047835


Fold  2: 2,2,2-Trifluoroethanol                        MSE = 0.019111


Fold  3: 2-Methyltetrahydrofuran [2-MeTHF]             MSE = 0.004584


Fold  4: Acetonitrile                                  MSE = 0.012244


Fold  5: Acetonitrile.Acetic Acid                      MSE = 0.028923


Fold  6: Butanone [MEK]                                MSE = 0.028776


Fold  7: Cyclohexane                                   MSE = 0.008340


Fold  8: DMA [N,N-Dimethylacetamide]                   MSE = 0.010872


Fold  9: Decanol                                       MSE = 0.014774


Fold 10: Diethyl Ether [Ether]                         MSE = 0.014629


Fold 11: Dihydrolevoglucosenone (Cyrene)               MSE = 0.007376


Fold 12: Dimethyl Carbonate                            MSE = 0.035863


Fold 13: Ethanol                                       MSE = 0.010598


Fold 14: Ethyl Acetate                                 MSE = 0.002462


Fold 15: Ethyl Lactate                                 MSE = 0.006037


Fold 16: Ethylene Glycol [1,2-Ethanediol]              MSE = 0.014251


Fold 17: IPA [Propan-2-ol]                             MSE = 0.012698


Fold 18: MTBE [tert-Butylmethylether]                  MSE = 0.016111


Fold 19: Methanol                                      MSE = 0.013715


Fold 20: Methyl Propionate                             MSE = 0.010011


Fold 21: THF [Tetrahydrofuran]                         MSE = 0.004474


Fold 22: Water.2,2,2-Trifluoroethanol                  MSE = 0.003622


Fold 23: Water.Acetonitrile                            MSE = 0.016660


Fold 24: tert-Butanol [2-Methylpropan-2-ol]            MSE = 0.002605

LISA + REx Single Solvent CV MSE: 0.014455 +/- 0.010858


In [None]:
# Analyze per-solvent errors: rank single-solvent folds by held-out MSE.
print("\n" + "="*60)
print("Per-Solvent Error Analysis")
print("="*60)

# Descending by MSE: worst-generalizing solvents first.
sorted_errors = sorted(solvent_errors_lisa.items(), key=lambda x: x[1], reverse=True)

print("\nTop 10 highest error solvents:")
for i, (solvent, mse) in enumerate(sorted_errors[:10]):
    print(f"  {i+1:2d}. {solvent:40s}: {mse:.6f}")

# sorted_errors is descending, so its tail must be reversed. That way rank 1
# is the lowest-error solvent, mirroring the highest-error table above.
# (Iterating the tail directly would label the 10th-lowest solvent as #1.)
print("\nTop 10 lowest error solvents:")
for i, (solvent, mse) in enumerate(reversed(sorted_errors[-10:])):
    print(f"  {i+1:2d}. {solvent:40s}: {mse:.6f}")

# Calculate variance of per-solvent errors (population variance, ddof=0)
error_variance = np.var(list(solvent_errors_lisa.values()))
print(f"\nVariance of per-solvent errors: {error_variance:.6f}")
print(f"Std dev of per-solvent errors: {np.sqrt(error_variance):.6f}")

In [None]:
# Leave-one-ramp-out CV on the binary-mixture data. Each fold holds out one
# ramp, trains a fresh LISARExModel on the rest, and records the held-out MSE
# under the ramp's "A.B" solvent name.
# NOTE(review): if two ramps ever share the same "A.B" name, the later fold
# overwrites the earlier one in mix_errors_lisa — confirm ramps are unique per pair.
print("\n" + "=" * 60)
print("Running Mixture CV with LISA + REx Model...")
print("=" * 60)

mix_splits = list(generate_leave_one_ramp_out_splits(X_mix, Y_full))
print(f"Number of folds: {len(mix_splits)}")

mix_errors_lisa = {}
mix_preds_lisa, mix_true_lisa = [], []

for fold_no, (train_idx, test_idx) in enumerate(mix_splits, start=1):
    X_train, Y_train = X_mix.iloc[train_idx], Y_full.iloc[train_idx]
    X_test, Y_test = X_mix.iloc[test_idx], Y_full.iloc[test_idx]
    test_mixture = X_test['SOLVENT NAME'].iloc[0]

    # fit() returns the model itself, so training and prediction chain.
    model = LISARExModel(lisa_alpha=0.5, rex_beta=1.0).fit(X_train, Y_train)
    preds = model.predict(X_test)

    y_true = Y_test.values
    mse = np.mean((preds - y_true) ** 2)
    mix_errors_lisa[test_mixture] = mse
    mix_preds_lisa.append(preds)
    mix_true_lisa.append(y_true)

    print(f"Fold {fold_no:2d}: {test_mixture:55s} MSE = {mse:.6f}")

mix_preds_lisa = np.vstack(mix_preds_lisa)
mix_true_lisa = np.vstack(mix_true_lisa)
mix_mse_lisa = np.mean((mix_preds_lisa - mix_true_lisa) ** 2)
# Population std (ddof=0) of the per-ramp fold MSEs
mix_std_lisa = np.std(list(mix_errors_lisa.values()))

print(f"\nLISA + REx Mixture CV MSE: {mix_mse_lisa:.6f} +/- {mix_std_lisa:.6f}")

In [None]:
# Overall CV score: row-weighted pooled MSE across both tasks, compared against
# the exp_030 reference CV.
print("\n" + "=" * 60)
print("LISA + REx Model Overall Results")
print("=" * 60)

n_single, n_mix = len(all_true_lisa), len(mix_true_lisa)
n_total = n_single + n_mix
overall_mse_lisa = (n_single * single_mse_lisa + n_mix * mix_mse_lisa) / n_total

print(f"\nSingle Solvent CV MSE: {single_mse_lisa:.6f} +/- {single_std_lisa:.6f} (n={n_single})")
print(f"Mixture CV MSE: {mix_mse_lisa:.6f} +/- {mix_std_lisa:.6f} (n={n_mix})")
print(f"Overall CV MSE: {overall_mse_lisa:.6f}")

baseline_cv = 0.008298  # exp_030 overall CV (GP 0.15 + MLP 0.55 + LGBM 0.3)
print("\nBaseline (exp_030): CV = 0.008298")
print(f"Improvement: {(baseline_cv - overall_mse_lisa) / baseline_cv * 100:.1f}%")

verdict = "\n✓ BETTER than baseline!" if overall_mse_lisa < baseline_cv else "\n✗ WORSE than baseline."
print(verdict)

In [None]:
# Final Summary: recap of this experiment against the exp_030 reference.
baseline_cv = 0.008298
rule = "=" * 60
print("\n" + rule)
print("EXPERIMENT 050 SUMMARY")
print(rule)

print("\nLISA + REx Model:")
print(f"  Single Solvent CV: {single_mse_lisa:.6f}")
print(f"  Mixture CV: {mix_mse_lisa:.6f}")
print(f"  Overall CV: {overall_mse_lisa:.6f}")
print(f"  vs Baseline: {(overall_mse_lisa - baseline_cv) / baseline_cv * 100:+.1f}%")

# error_variance comes from the single-solvent fold MSEs only.
print(f"\nPer-solvent error variance: {error_variance:.6f}")
print(f"Per-solvent error std: {np.sqrt(error_variance):.6f}")

print("\nKey insights:")
for line in (
    "1. LISA augmentation creates interpolated samples across solvents",
    "2. REx loss penalizes variance of per-solvent losses",
    "3. Together they encourage domain-invariant features",
):
    print(line)

if overall_mse_lisa < baseline_cv:
    conclusion = (
        "\nCONCLUSION: LISA + REx IMPROVES overall CV!",
        "Consider submitting to test if this changes the CV-LB relationship.",
    )
else:
    conclusion = (
        "\nCONCLUSION: LISA + REx does NOT improve overall CV.",
        "The domain adaptation techniques don't help for this problem.",
    )
for line in conclusion:
    print(line)

print("\nRemaining submissions: 5")
print("Best model: exp_030 (GP 0.15 + MLP 0.55 + LGBM 0.3) with CV 0.008298, LB 0.0877")