# MOL-MOE for Caco-2 Permeability Prediction

This notebook trains separate MOL-MOE models for:
1. **Caco-2 Permeability Papp A>B** (regression)
2. **Caco-2 Permeability Efflux** (regression)

**Dataset Info:**
- Papp A>B: 2,157 molecules (range: 0-51.41)
- Efflux: 2,161 molecules (range: 0.26-105.64)

**Architecture:**
- 12 experts total (4 SMI-TED + 4 SELFIES-TED + 4 MHG-GNN)
- k=4 experts activated per sample
- Single-task regression approach

## 0. RunPod Setup (Auto-Configuration)

This section automatically configures the environment for RunPod/Jupyter instances:
- Detects platform and GPU availability
- Installs the `uv` package manager for fast dependency installation
- Installs system libraries needed for RDKit rendering (Linux only, best effort)
- Installs PyTorch 2.2.0 with CUDA 11.8 and all required packages
- Configures paths relative to the notebook directory (no hardcoded absolute paths)

- **First run:** ~3-5 minutes for package installation
- **Subsequent runs:** ~10-30 seconds (packages cached by uv)

In [None]:
# ============================================================
# SETUP CELL 1: Environment Detection & Path Configuration
# ============================================================
import os
import sys
import platform
import shutil
import subprocess
from pathlib import Path


def _has_nvidia_gpu(timeout=30):
    """Return True if `nvidia-smi` is on PATH and exits with status 0.

    The binary is executed directly (no shell), so the probe does not rely on
    POSIX `/dev/null` shell redirection, which fails under Windows cmd.exe and
    would misreport a GPU machine as CPU-only.

    Args:
        timeout: Seconds to wait for `nvidia-smi` before treating it as absent.
    """
    smi = shutil.which('nvidia-smi')
    if smi is None:
        return False
    try:
        probe = subprocess.run(
            [smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=timeout,  # a wedged driver must not hang the notebook
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return probe.returncode == 0


# Detect environment
IS_COLAB = 'google.colab' in sys.modules
IS_JUPYTER = 'ipykernel' in sys.modules
DEVICE_NAME = 'GPU' if _has_nvidia_gpu() else 'CPU'

print("="*60)
print("ENVIRONMENT DETECTION")
print("="*60)
print(f"Platform: {platform.system()}")
print(f"Python: {sys.version.split()[0]}")
print(f"Runtime: {'Colab' if IS_COLAB else 'Jupyter' if IS_JUPYTER else 'Unknown'}")
print(f"Device: {DEVICE_NAME}")
print("="*60)

# Define base paths using pathlib for cross-platform compatibility.
# Assumes the notebook is run from materials/models/mol_moe/notebooks/.
NOTEBOOK_DIR = Path.cwd()
MATERIALS_ROOT = NOTEBOOK_DIR.parent.parent.parent  # notebooks/ -> mol_moe/ -> models/ -> materials/
MOL_MOE_ROOT = NOTEBOOK_DIR.parent  # models/mol_moe/
EXPERTS_DIR = MOL_MOE_ROOT / "experts"
MOE_DIR = MOL_MOE_ROOT / "moe"
DATA_DIR = MATERIALS_ROOT  # CSVs are at materials/ level

print("\nPath Configuration:")
print(f"  Notebook directory: {NOTEBOOK_DIR}")
print(f"  Materials root: {MATERIALS_ROOT}")
print(f"  MoE root: {MOL_MOE_ROOT}")
print(f"  Data directory: {DATA_DIR}")

# Verify critical paths exist. Explicit raises (not asserts) so the check
# still runs when Python is started with -O.
for label, required in (
    ("Materials root", MATERIALS_ROOT),
    ("Experts directory", EXPERTS_DIR),
    ("MoE directory", MOE_DIR),
    ("Data directory", DATA_DIR),
):
    if not required.exists():
        raise FileNotFoundError(f"{label} not found: {required}")

print("\n✓ All critical paths verified")

In [None]:
# ============================================================
# SETUP CELL 2: UV Installation & Verification
# ============================================================
import os
import shutil
import subprocess
import sys
import sysconfig


def check_uv():
    """Return True if a working `uv` executable is on PATH, printing its version.

    A `uv` binary that is found but fails `uv --version` counts as missing.
    """
    uv_path = shutil.which('uv')
    if uv_path is None:
        return False
    result = subprocess.run([uv_path, '--version'], capture_output=True, text=True)
    if result.returncode != 0:
        return False
    print(f"✓ uv found: {result.stdout.strip()}")
    return True


def _ensure_scripts_dir_on_path():
    """Prepend this interpreter's scripts directory to PATH if it is missing.

    pip installs the `uv` executable into that directory. When the kernel's
    environment was never "activated", the directory is not on PATH, so both
    `shutil.which('uv')` and the later `subprocess.run(['uv', ...])` calls in
    this notebook would fail to find the freshly installed binary.
    """
    scripts_dir = sysconfig.get_path('scripts')
    if not scripts_dir:
        return
    current = os.environ.get('PATH', '')
    if scripts_dir in current.split(os.pathsep):
        return
    os.environ['PATH'] = scripts_dir + (os.pathsep + current if current else '')


if not check_uv():
    print("Installing uv...")
    # Install uv from PyPI into the kernel's own interpreter via pip
    subprocess.run([
        sys.executable, '-m', 'pip', 'install', '--quiet', 'uv'
    ], check=True)
    _ensure_scripts_dir_on_path()

    if check_uv():
        print("✓ uv installed successfully")
    else:
        raise RuntimeError("Failed to install uv")
else:
    print("uv already installed")

In [None]:
# ============================================================
# SETUP CELL 2.5: Install System Dependencies (Linux only)
# ============================================================
import os
import platform
import subprocess

if platform.system() == 'Linux':
    print("Installing system dependencies for RDKit...")
    # Never let apt/debconf wait for interactive input: stdout/stderr are
    # captured, so a prompt would block the cell invisibly.
    apt_env = {**os.environ, 'DEBIAN_FRONTEND': 'noninteractive'}
    try:
        # Refresh package lists. A failed refresh is tolerated on purpose:
        # cached lists may still be enough for the install below.
        subprocess.run(
            ['apt-get', 'update', '-qq'],
            capture_output=True, text=True, env=apt_env,
        )

        # Install X11 libraries needed for RDKit rendering
        subprocess.run([
            'apt-get', 'install', '-y', '-qq',
            'libxrender1',
            'libxext6',
            'libsm6',
            'libfontconfig1'
        ], check=True, capture_output=True, text=True, env=apt_env)

        print("✓ System dependencies installed")
    except subprocess.CalledProcessError as e:
        print("⚠️  Warning: Could not install system packages (may need sudo)")
        print("   RDKit rendering may not work, but training will still function")
        # Surface apt's own last error line so the cause is diagnosable
        detail = (e.stderr or '').strip().splitlines()
        if detail:
            print(f"   apt-get: {detail[-1]}")
    except FileNotFoundError:
        print("⚠️  Warning: apt-get not found (not a Debian/Ubuntu system)")
        print("   RDKit rendering may not work, but training will still function")
else:
    print(f"Platform: {platform.system()} - skipping Linux system dependencies")

In [None]:
# ============================================================
# SETUP CELL 3: Install Dependencies with UV
# ============================================================
import subprocess
import sys

print("Installing dependencies with uv (this may take a few minutes on first run)...")
print("Configuration:")
print("  - Python: 3.10+")
print("  - PyTorch: 2.2.0 with CUDA 11.8")
print("  - Installing to: system environment")
print()


def _uv_install(*args):
    """Run `uv pip install <args>` against this kernel's interpreter.

    Output is captured. A non-zero exit raises subprocess.CalledProcessError
    whose `stdout`/`stderr` are bytes.
    """
    subprocess.run(
        ['uv', 'pip', 'install', '--python', sys.executable, *args],
        check=True,
        capture_output=True,
    )


try:
    # Step 1: Install PyTorch with CUDA 11.8
    print("[1/3] Installing PyTorch with CUDA 11.8...")
    _uv_install(
        '--index-url', 'https://download.pytorch.org/whl/cu118',
        'torch==2.2.0',
        'torchvision==0.17.0',
        'torchaudio==2.2.0',
    )
    print("      ✓ PyTorch 2.2.0 with CUDA 11.8 installed")

    # Step 2: torch-scatter wheels are served from the PyG index for this exact torch/CUDA build
    print("[2/3] Installing torch-scatter...")
    _uv_install(
        '--find-links', 'https://data.pyg.org/whl/torch-2.2.0+cu118.html',
        'torch-scatter',
    )
    print("      ✓ torch-scatter installed")

    # Step 3: Install remaining dependencies
    print("[3/3] Installing remaining dependencies...")
    remaining_deps = [
        'torch-geometric>=2.3.1',
        'matplotlib==3.9.2',
        'numpy>=1.26.1,<2.0.0',
        'pandas>=1.5.3',
        'scikit-learn>=1.5.0',
        'rdkit>=2024.3.5',
        'datasets>=2.13.1',
        'huggingface-hub',
        'transformers>=4.38',
        'selfies>=2.1.0',
        'tqdm>=4.66.4',
        'xgboost==2.0.0',
        'seaborn',  # For plotting
    ]
    _uv_install(*remaining_deps)
    print("      ✓ All dependencies installed successfully")

except subprocess.CalledProcessError as e:
    print(f"\n❌ Installation failed: {e}")
    print("Trying to show error output:")
    if e.stderr:
        # errors='replace': a non-UTF-8 byte must not mask the original failure
        print(e.stderr.decode(errors='replace'))
    raise

# Verify installations
import torch
print(f"\nVerification:")
print(f"  PyTorch version: {torch.__version__}")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

print("\n" + "="*60)
print("✓ Environment setup complete!")
print("="*60)

In [None]:
# ============================================================
# SETUP CELL 4: Configure Module Import Paths
# ============================================================
import sys

# Add project directories to Python path (absolute paths, so later working
# directory changes do not matter). Each directory is moved to the front of
# sys.path in turn, so afterwards the search order is MOE_DIR, EXPERTS_DIR,
# MOL_MOE_ROOT:
#   MOL_MOE_ROOT -> 'from moe import ...' and 'from experts.... import ...'
#   EXPERTS_DIR  -> 'from mhg_model import ...' (needed for unpickling)
#   MOE_DIR      -> 'from models import ...'
for _project_dir in (MOL_MOE_ROOT, EXPERTS_DIR, MOE_DIR):
    _entry = str(_project_dir)
    # Drop any earlier copy so re-running this cell does not pile up duplicates
    while _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)

print("Module search paths configured:")
for i, path in enumerate(sys.path[:6]):
    print(f"  {i}: {path}")

# Verify imports work
try:
    from moe import MoE
    print("\n✓ MoE module importable")
except ImportError as e:
    print(f"\n✗ MoE import failed: {e}")
    print("  Ensure you're running from materials/models/mol_moe/notebooks/")

try:
    from models import Net
    print("✓ Net model importable")
except ImportError as e:
    print(f"✗ Net import failed: {e}")

# Test that mhg_model can be imported (needed for unpickling)
try:
    import mhg_model
    print("✓ mhg_model module importable (needed for model loading)")
except ImportError as e:
    print(f"⚠️  mhg_model not directly importable: {e}")
    print("   This may cause issues when loading pre-trained MHG-GNN model")

print("\n✓ Module paths configured successfully")
print("\n" + "="*60)
print("SETUP COMPLETE - Ready to proceed with training!")
print("="*60)

## 1. Setup and Imports

In [None]:
# System
import warnings
warnings.filterwarnings("ignore")

# Deep learning
import torch
import torch.nn.functional as F
from torch import nn
from moe import MoE, train
from models import Net

# Machine learning
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Chemistry
from rdkit import Chem
from rdkit.Chem import Descriptors

# Try to enable PandasTools rendering (may fail on headless systems)
try:
    from rdkit.Chem import PandasTools
    PandasTools.RenderImagesInAllDataFrames(True)
    print("✓ RDKit rendering enabled")
except ImportError as e:
    # Commonly a missing shared library on headless hosts. Print the actual
    # reason so an unrelated import failure is not misdiagnosed.
    print("⚠️  RDKit rendering disabled (missing system libraries)")
    print(f"   Reason: {e}")
    print("   Training will work normally, but molecule images won't display in DataFrames")
    PandasTools = None

def normalize_smiles(smi, canonical=True, isomeric=False):
    """Return an RDKit-normalized SMILES string, or None if *smi* is unusable.

    Args:
        smi: Input SMILES. Non-string values (e.g. NaN from a pandas column)
            yield None.
        canonical: Forwarded to ``Chem.MolToSmiles(canonical=...)``.
        isomeric: Forwarded as ``isomericSmiles``. With the default False,
            stereochemistry is dropped, so stereoisomers map to one string.

    Returns:
        The normalized SMILES, or None when RDKit cannot parse/sanitize it.
    """
    if not isinstance(smi, str):
        return None
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:  # RDKit reports parse/sanitization failure by returning None
            return None
        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    except Exception:
        # Best effort by design: any RDKit error means "not normalizable",
        # but (unlike a bare except) Ctrl-C / SystemExit still propagate.
        return None

# Seed the global torch and NumPy RNGs for repeatable runs
# (Python's built-in `random` module is not seeded here).
torch.manual_seed(42)
np.random.seed(42)
# Use CUDA when available, otherwise fall back to CPU; DEVICE is used by all later cells
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

## 2. Load Foundation Models

In [None]:
import types

from experts.selfies_ted.load import SELFIES

print("Loading SELFIES-TED...")
model_selfies = SELFIES()
model_selfies.load()


def patched_get_embedding(self, selfies):
    """Mean-pool SELFIES-TED encoder states for one `datasets` batch.

    Fix 1: move the tokenizer outputs to the encoder's device before the forward.

    Args:
        selfies: Batch dict from ``Dataset.map(batched=True)``; its
            ``'selfies'`` entry is the list of SELFIES strings to embed.

    Returns:
        ``{'embedding': [[float, ...], ...]}`` with one mean-pooled vector
        per input string, as plain Python lists on the CPU.
    """
    encoding = self.tokenizer(selfies['selfies'], return_tensors='pt', padding=True, truncation=True)

    # Move inputs to the same device as the model
    device = next(self.model.parameters()).device
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: the result is detached and converted to lists below, so
    # building an autograd graph here would just waste (GPU) memory.
    with torch.no_grad():
        outputs = self.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
    model_output = outputs.last_hidden_state

    # Mean over non-padding positions only
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    mean_embeddings = sum_embeddings / sum_mask

    # Move back to CPU for dataset processing
    return {'embedding': mean_embeddings.detach().cpu().numpy().tolist()}


def patched_encode(self, smiles_list, use_gpu=False, return_tensor=True):
    """Embed SMILES with SELFIES-TED, one output row per input SMILES.

    Fix 2: disable `datasets` multiprocessing (CUDA cannot be used after fork).

    Args:
        smiles_list: SMILES strings to embed.
        use_gpu: Accepted for signature compatibility; ignored (the device is
            taken from the model's parameters).
        return_tensor: Return a torch tensor if True, else a NumPy array.

    Returns:
        Array/tensor with ``len(smiles_list)`` rows. Rows for molecules that
        could not be converted to SELFIES are all zeros; their indices are
        recorded in ``self.invalid``.
    """
    import pandas as pd
    import numpy as np
    from datasets import Dataset

    # Convert to SELFIES. Failed molecules keep a "" placeholder so that row i
    # of the result still corresponds to smiles_list[i].
    selfies = []
    self.invalid = []
    for i, smile in enumerate(smiles_list):
        try:
            encoded = self.encoder(smile)
        except Exception:
            encoded = None
        if encoded is None:
            selfies.append("")
            self.invalid.append(i)
        else:
            selfies.append(encoded)

    selfies_df = pd.DataFrame(selfies, columns=["selfies"])
    data = Dataset.from_pandas(selfies_df)

    # CRITICAL: num_proc=None to disable multiprocessing (prevents CUDA fork issues)
    embedding = data.map(self.get_embedding, batched=True, num_proc=None, batch_size=128)

    # Convert Column to numpy array (datasets returns Column objects)
    emb = np.array(list(embedding["embedding"]))

    # Placeholder rows already exist for invalid molecules: overwrite them with
    # zeros. Inserting extra rows would shift every later molecule by one and
    # return more rows than inputs.
    if self.invalid:
        emb[self.invalid] = 0.0

    if return_tensor:
        return torch.tensor(emb)
    return emb


# Bind the patches to this instance. types.MethodType captures the object now,
# instead of lambdas that re-read the global `model_selfies` on every call.
model_selfies.get_embedding = types.MethodType(patched_get_embedding, model_selfies)
model_selfies.encode = types.MethodType(patched_encode, model_selfies)

print("✓ SELFIES-TED loaded")
print("  - Multiprocessing disabled (CUDA compatibility)")
print("  - Device handling patched")

In [None]:
# Load the MHG-GNN (graph) expert. Per SETUP CELL 4, loading it needs
# `mhg_model` importable as a top-level module (used during unpickling).
from experts.mhg_model.load import load

print("Loading MHG-GNN...")
mhg_gnn = load()
print("✓ MHG-GNN loaded")

In [None]:
# Load the SMI-TED expert. MolTranBertTokenizer is imported here because the
# MoE section later builds its SMILES tokenizer from SMI-TED's vocab file.
from experts.smi_ted_light.load import load_smi_ted, MolTranBertTokenizer

print("Loading SMI-TED...")
smi_ted = load_smi_ted()
print("✓ SMI-TED loaded")

## 3. Select Target Endpoint

**Choose which endpoint to train:**
- `papp_ab`: Caco-2 Permeability Papp A>B
- `efflux`: Caco-2 Permeability Efflux

In [None]:
# ============================================
# SELECT YOUR ENDPOINT HERE
# ============================================
ENDPOINT = 'papp_ab'  # Options: 'papp_ab' or 'efflux'
# ============================================

# Endpoint -> (CSV file under DATA_DIR, target column, name used for saved artifacts)
_ENDPOINT_SPECS = {
    'papp_ab': (
        'train_Caco2_Permeability_Papp_AB.csv',
        'Caco-2 Permeability Papp A>B',
        'Caco2_Papp_AB',
    ),
    'efflux': (
        'train_Caco2_Permeability_Efflux.csv',
        'Caco-2 Permeability Efflux',
        'Caco2_Efflux',
    ),
}

if ENDPOINT not in _ENDPOINT_SPECS:
    raise ValueError("ENDPOINT must be 'papp_ab' or 'efflux'")

_csv_name, target_col, model_name = _ENDPOINT_SPECS[ENDPOINT]
data_file = DATA_DIR / _csv_name  # DATA_DIR comes from SETUP CELL 1

# Fail early if the CSV is missing
assert data_file.exists(), f"Data file not found: {data_file}"

print(f"Selected endpoint: {ENDPOINT}")
print(f"Target column: {target_col}")
print(f"Model name: {model_name}")
print(f"Data file: {data_file}")

## 4. Load and Prepare Data

In [None]:
# Load data: the raw CSV for the selected endpoint. Later cells read its
# 'SMILES' column and the target_col column.
df = pd.read_csv(data_file)
print(f"Original dataset shape: {df.shape}")
print(f"\nFirst few rows:")
df.head()  # last expression in the cell, so Jupyter renders it as output

In [None]:
# Normalize SMILES
print("Normalizing SMILES...")
df['canon_smiles'] = df['SMILES'].apply(normalize_smiles)

# Remove invalid SMILES and rows with a missing target
original_count = len(df)
df = df.dropna(subset=['canon_smiles', target_col])
print(f"Removed {original_count - len(df)} invalid entries")

# Merge duplicate structures. normalize_smiles drops stereochemistry by
# default, so stereoisomers (and repeated measurements) collapse onto the same
# canonical SMILES. Left as separate rows, the random split below could place
# an identical model input in both train and test, inflating test metrics.
# Duplicates are merged into one row with the mean target value.
n_before_dedup = len(df)
df = (
    df.groupby('canon_smiles', as_index=False, sort=False)
      .agg({'SMILES': 'first', target_col: 'mean'})
)
print(f"Merged {n_before_dedup - len(df)} duplicate structures (targets averaged)")
print(f"Final dataset shape: {df.shape}")

# Show target statistics
print(f"\n{target_col} statistics:")
print(df[target_col].describe())

In [None]:
# Visualize target distribution: histogram on the left, box plot on the right
fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(12, 4))
target_values = df[target_col]

ax_hist.hist(target_values, bins=50, edgecolor='black', alpha=0.7)
ax_hist.set_xlabel(target_col)
ax_hist.set_ylabel('Frequency')
ax_hist.set_title(f'Distribution of {target_col}')

ax_box.boxplot(target_values)
ax_box.set_ylabel(target_col)
ax_box.set_title(f'Box Plot of {target_col}')

# Light grid on both panels
for ax in (ax_hist, ax_box):
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Split data: 70% train, 15% validation, 15% test (fixed random_state for reproducibility)
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
valid_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

for split_label, split_frame in (("Training", train_df), ("Validation", valid_df), ("Test", test_df)):
    print(f"{split_label} set: {len(split_frame)} samples")

# Model inputs are the normalized SMILES strings
smiles_col = 'canon_smiles'


def _to_inputs_and_targets(frame):
    """Return (list of SMILES strings, float32 target tensor) for one split."""
    inputs = frame[smiles_col].to_list()
    targets = torch.tensor(frame[target_col].values, dtype=torch.float32)
    return inputs, targets


X_train, y_train = _to_inputs_and_targets(train_df)
X_valid, y_valid = _to_inputs_and_targets(valid_df)
X_test, y_test = _to_inputs_and_targets(test_df)

## 5. Configure and Train MOE Model

In [None]:
# ============================================
# HYPERPARAMETERS - Adjust as needed
# ============================================

# Model architecture
input_size = 768          # MoE input_size; presumably the SMI-TED token-embedding width (TODO confirm)
output_size = 2048        # MoE output size; also used as the Net's smiles_embed_dim
num_experts = 12          # Total number of experts (should match the 12-entry `models` list)
k = 4                     # Number of experts the gate activates per sample

# Training settings
batch_size = 32           # Molecules per training step; increase if GPU memory allows
learning_rate = 1e-4      # Initial AdamW learning rate (ReduceLROnPlateau may lower it)
epochs = 150              # Maximum passes over the training set; the best-validation epoch is checkpointed

# Regression output
output_dim = 1            # Single target regression
dropout = 0.2             # Dropout rate passed to the Net predictor head

print("Configuration:")
print(f"  - Batch size: {batch_size}")
print(f"  - Learning rate: {learning_rate}")
print(f"  - Epochs: {epochs}")
print(f"  - Experts: {num_experts}, activating {k} per sample")
print(f"  - Device: {DEVICE}")

In [None]:
import types

from moe import SparseDispatcher

# Define experts (4 per modality). Each modality's loaded model object is
# listed four times.
models = [
    smi_ted, smi_ted, smi_ted, smi_ted,              # SMI-TED experts
    model_selfies, model_selfies, model_selfies, model_selfies,  # SELFIES-TED experts
    mhg_gnn, mhg_gnn, mhg_gnn, mhg_gnn                # MHG-GNN experts
]

# Initialize tokenizer with dynamic path
vocab_path = EXPERTS_DIR / 'smi_ted_light' / 'bert_vocab_curated.txt'
assert vocab_path.exists(), f"Vocab file not found: {vocab_path}"
tokenizer = MolTranBertTokenizer(str(vocab_path))

# Initialize MOE
print("Initializing MOE model...")
moe_model = MoE(
    input_size=input_size,
    output_size=output_size,
    num_experts=num_experts,
    models=models,
    tokenizer=tokenizer,
    tok_emb=smi_ted.encoder.tok_emb,
    k=k,
    noisy_gating=True,      # Use noisy gating for better exploration
    verbose=False
).to(DEVICE)


def fixed_forward(self, smiles):
    """Mean-pooled token embeddings for a batch of SMILES, on the module's device.

    Fix 1: fixes the "Expected all tensors to be on the same device" error.
    The tokenizer returns CPU tensors, so ids and mask are moved to the device
    of this module's parameters before the embedding lookup.
    """
    tokens = self.tokenizer(smiles, padding=True, truncation=True, max_length=512, return_tensors='pt')

    # Move tokens to the same device as the model
    device = next(self.parameters()).device
    idx = tokens['input_ids'].to(device)
    mask = tokens['attention_mask'].to(device)

    # Average token embeddings over non-padding positions only
    token_embeddings = self.tok_emb(idx)
    input_mask_expanded = mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


# Bind the patch to this specific embd_net instance. types.MethodType captures
# the object now, unlike a lambda that re-reads the global `moe_model` per call.
moe_model.embd_net.forward = types.MethodType(fixed_forward, moe_model.embd_net)

# Fix 2: Patch SparseDispatcher.dispatch to use pure Python list operations.
# This fixes pandas/numpy indexing issues with CUDA tensors.
# Only patch once (prevents wrapping the patch again if the cell is re-run).
if not hasattr(SparseDispatcher.dispatch, '_is_patched'):
    _original_dispatch = SparseDispatcher.dispatch

    def patched_dispatch(self, inp):
        """Route inputs to experts using pure Python list operations.

        Reads ``self._batch_index`` (a tensor of input positions, in expert
        order) and ``self._part_sizes`` (how many of those positions each
        expert receives) and returns one list of inputs per expert.
        """
        # Tensor.tolist() copies to host itself, so CUDA tensors need no explicit .cpu()
        batch_index_list = self._batch_index.tolist()

        # Gather inputs in expert order
        inp_expanded = [inp[i] for i in batch_index_list]

        # Split into consecutive parts according to _part_sizes
        result = []
        start_idx = 0
        for part_size in self._part_sizes:
            result.append(inp_expanded[start_idx:start_idx + part_size])
            start_idx += part_size

        return result

    # Mark as patched and replace
    patched_dispatch._is_patched = True
    SparseDispatcher.dispatch = patched_dispatch
    print("✓ SparseDispatcher.dispatch patched (pure Python list ops)")
else:
    print("✓ SparseDispatcher.dispatch already patched")

# Initialize predictor network
net = Net(smiles_embed_dim=output_size, dropout=dropout, output_dim=output_dim)
net.apply(smi_ted._init_weights)
net = net.to(DEVICE)

print("✓ Models initialized")
print("✓ Device compatibility patches applied (EmbeddingNet + SparseDispatcher)")

In [None]:
# Objective: mean squared error on the raw target values
loss_fn = nn.MSELoss()

# Optimize the MoE (gating/embedding) and the predictor head jointly
params = [*moe_model.parameters(), *net.parameters()]
optim = torch.optim.AdamW(params, lr=learning_rate, weight_decay=1e-5)

# Halve the learning rate when validation loss stalls for 10 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='min',
    factor=0.5,
    patience=10,
    verbose=True,
)

# Each item is a (SMILES string, scalar target tensor) pair; workers stay in-process
train_pairs = list(zip(X_train, y_train))
train_loader = torch.utils.data.DataLoader(
    train_pairs,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

print(f"Training batches per epoch: {len(train_loader)}")

In [None]:
# ============================================================
# PRE-TRAINING VALIDATION
# ============================================================
import shutil

print("Running pre-training validation checks...\n")

# 1. GPU Check (this notebook's training run requires CUDA)
assert torch.cuda.is_available(), "❌ CUDA not available!"
print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"  Memory: {gpu_mem_gb:.1f} GB")

# 2. Data Check: read the CSV once and reuse it for the column check and the count
assert data_file.exists(), f"❌ Data file not found: {data_file}"
raw_df = pd.read_csv(data_file)
assert target_col in raw_df.columns, f"❌ Target column '{target_col}' not found"
full_dataset_size = len(raw_df)
print(f"✓ Data file valid: {full_dataset_size} samples")

# 3. Model Check: confirm the modules exist AND actually live on DEVICE
assert moe_model is not None, "❌ MoE model not initialized"
assert net is not None, "❌ Net model not initialized"
for _label, _module in (("MoE", moe_model), ("Net", net)):
    _first_param = next(_module.parameters(), None)
    if _first_param is not None:
        assert _first_param.device.type == DEVICE.type, (
            f"❌ {_label} model is on {_first_param.device}, expected {DEVICE}"
        )
print(f"✓ Models initialized and on device: {DEVICE}")

# 4. Disk Space Check: measure the filesystem the checkpoints are written to,
#    not whatever the current working directory happens to be
CHECKPOINT_DIR = NOTEBOOK_DIR / 'checkpoints'
disk_usage = shutil.disk_usage(NOTEBOOK_DIR)
free_gb = disk_usage.free / 1e9
assert free_gb > 10, f"❌ Low disk space: {free_gb:.1f} GB"
print(f"✓ Disk space: {free_gb:.1f} GB free")

# 5. Model size report (informational only)
model_params = sum(p.numel() for p in moe_model.parameters()) + sum(p.numel() for p in net.parameters())
print(f"✓ Model parameters: {model_params/1e6:.1f}M")

# 6. Create checkpoint directory
CHECKPOINT_DIR.mkdir(exist_ok=True)
print(f"✓ Checkpoint directory: {CHECKPOINT_DIR}")

print("\n" + "="*60)
print("All validation checks passed! Ready to train.")
print("="*60)

In [None]:
# Custom training loop with validation; keeps the best-validation checkpoint
from tqdm import tqdm

train_losses = []
valid_losses = []
best_valid_loss = float('inf')
best_epoch = None  # stays None if validation loss never improves (e.g. NaN losses)
checkpoint_path = CHECKPOINT_DIR / f'best_{model_name}_moe_model.pt'
n_train_samples = len(train_loader.dataset)

print("Starting training...\n")

for epoch in range(epochs):
    # Training
    moe_model.train()
    net.train()
    epoch_loss = 0.0

    for (x, y) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        y = y.to(DEVICE)

        optim.zero_grad()

        # Forward pass. aux_loss is an extra loss term returned by the MoE
        # and is added to the objective below.
        embd, aux_loss = moe_model(x)
        # reshape(-1) rather than squeeze(): squeeze() would turn a final
        # batch of size 1 into a 0-d tensor that no longer matches y's shape
        y_hat = net(embd).reshape(-1)

        # Calculate loss
        loss = loss_fn(y_hat, y)
        total_loss = loss + aux_loss

        # Backward pass
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # Gradient clipping
        optim.step()

        # Weight each batch mean by its size so the epoch value is a per-sample
        # MSE (comparable to valid_loss) even when the last batch is short
        epoch_loss += loss.item() * y.size(0)

    avg_train_loss = epoch_loss / n_train_samples
    train_losses.append(avg_train_loss)

    # Validation (whole validation set in one forward pass)
    moe_model.eval()
    net.eval()
    with torch.no_grad():
        valid_embd, _ = moe_model(X_valid, verbose=False)
        valid_preds = net(valid_embd).reshape(-1)
        valid_loss = loss_fn(valid_preds.cpu(), y_valid).item()

    valid_losses.append(valid_loss)
    scheduler.step(valid_loss)

    # Save best model to checkpoint directory
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'moe_state_dict': moe_model.state_dict(),
            'net_state_dict': net.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'valid_loss': valid_loss,
        }, checkpoint_path)

    # Print progress
    if (epoch + 1) % 10 == 0:
        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {valid_loss:.4f}")
        print(f"  Best Valid Loss: {best_valid_loss:.4f}\n")

print("\n✓ Training completed!")
print(f"Best validation loss: {best_valid_loss:.4f}")
if best_epoch is None:
    print("⚠️  No checkpoint was saved: validation loss never improved (check for NaN losses)")
else:
    print(f"Model saved to: {checkpoint_path}")

In [None]:
# Plot per-epoch training and validation MSE
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in ((train_losses, 'Training Loss'), (valid_losses, 'Validation Loss')):
    ax.plot(curve, label=curve_label, alpha=0.8)
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.set_title(f'Training History - {model_name}')
ax.legend()
ax.grid(alpha=0.3)
plt.show()

## 6. Evaluate on Test Set

In [None]:
# Load best model from checkpoint directory
checkpoint_path = CHECKPOINT_DIR / f'best_{model_name}_moe_model.pt'
# map_location: tensors land on the current DEVICE even if the checkpoint was
# written from another device. weights_only=True: the checkpoint holds only
# state dicts and plain numbers, and this matches the default of newer torch
# releases, so loading behaves the same across versions.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
moe_model.load_state_dict(checkpoint['moe_state_dict'])
net.load_state_dict(checkpoint['net_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Checkpoint: {checkpoint_path}")

In [None]:
# Evaluate the MoE + Net model on the held-out test set
moe_model.eval()
net.eval()

with torch.no_grad():
    test_embd, _ = moe_model(X_test, verbose=False)
    # reshape(-1) always yields a 1-D vector; squeeze() would return a 0-d
    # array for a single-sample test set and break the metric calls below
    test_preds = net(test_embd).reshape(-1)
    test_preds_np = test_preds.cpu().numpy()
    y_test_np = y_test.numpy()

# Calculate metrics
rmse = np.sqrt(mean_squared_error(y_test_np, test_preds_np))
mae = mean_absolute_error(y_test_np, test_preds_np)
r2 = r2_score(y_test_np, test_preds_np)

print("\n" + "="*50)
print(f"TEST SET RESULTS - {model_name}")
print("="*50)
print(f"RMSE: {rmse:.4f}")
print(f"MAE:  {mae:.4f}")
print(f"R²:   {r2:.4f}")
print("="*50)

In [None]:
# Parity plot: predicted vs. actual test targets for MoE + Net
fig, ax = plt.subplots(figsize=(8, 8))

ax.scatter(y_test_np, test_preds_np, alpha=0.6, edgecolors='black', linewidth=0.5)

# y = x reference line spanning both the actual and the predicted ranges
both_series = (y_test_np, test_preds_np)
min_val = min(series.min() for series in both_series)
max_val = max(series.max() for series in both_series)
ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect prediction')

ax.set_xlabel(f'Actual {target_col}', fontsize=12)
ax.set_ylabel(f'Predicted {target_col}', fontsize=12)
ax.set_title(f'Parity Plot - {model_name}\nRMSE={rmse:.3f}, R²={r2:.3f}', fontsize=14)
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Residual diagnostics (actual - predicted) for MoE + Net on the test set
residuals = y_test_np - test_preds_np

fig, (ax_scatter, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

# Left: residuals against predictions
ax_scatter.scatter(test_preds_np, residuals, alpha=0.6, edgecolors='black', linewidth=0.5)
ax_scatter.axhline(y=0, color='r', linestyle='--', linewidth=2)
ax_scatter.set_xlabel('Predicted Values', fontsize=12)
ax_scatter.set_ylabel('Residuals', fontsize=12)
ax_scatter.set_title('Residuals vs Predicted', fontsize=12)

# Right: residual histogram
ax_hist.hist(residuals, bins=30, edgecolor='black', alpha=0.7)
ax_hist.axvline(x=0, color='r', linestyle='--', linewidth=2)
ax_hist.set_xlabel('Residuals', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title('Distribution of Residuals', fontsize=12)

for ax in (ax_scatter, ax_hist):
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Mean of residuals: {residuals.mean():.4f}")
print(f"Std of residuals: {residuals.std():.4f}")

## 7. Train XGBoost on MOE Embeddings

In [None]:
# Embed every split with the trained MoE and convert to NumPy features for XGBoost
print("Extracting embeddings...")
moe_model.eval()

with torch.no_grad():
    xgb_train, xgb_valid, xgb_test = (
        moe_model(split_smiles, verbose=False)[0].cpu().numpy()
        for split_smiles in (X_train, X_valid, X_test)
    )

y_train_np, y_valid_np = y_train.numpy(), y_valid.numpy()

for split_label, features in (("Train", xgb_train), ("Valid", xgb_valid), ("Test", xgb_test)):
    print(f"{split_label} embeddings shape: {features.shape}")

In [None]:
# Train an XGBoost regressor on the MoE embeddings, early-stopping on the validation split
print("Training XGBoost...")
xgb_params = dict(
    n_estimators=2000,
    learning_rate=0.05,
    max_depth=8,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    early_stopping_rounds=50,
    eval_metric='rmse',
)
xgb_model = XGBRegressor(**xgb_params)

validation_pairs = [(xgb_valid, y_valid_np)]
xgb_model.fit(xgb_train, y_train_np, eval_set=validation_pairs, verbose=100)

print("\n✓ XGBoost training completed")

In [None]:
# Score the XGBoost model on the held-out test embeddings
xgb_preds = xgb_model.predict(xgb_test)

xgb_rmse = np.sqrt(mean_squared_error(y_test_np, xgb_preds))
xgb_mae = mean_absolute_error(y_test_np, xgb_preds)
xgb_r2 = r2_score(y_test_np, xgb_preds)

rule = "=" * 50
print("\n" + rule)
print(f"XGBoost TEST SET RESULTS - {model_name}")
print(rule)
for metric_label, metric_value in (("RMSE:", xgb_rmse), ("MAE: ", xgb_mae), ("R²:  ", xgb_r2)):
    print(f"{metric_label} {metric_value:.4f}")
print(rule)

# Persist the booster alongside the MoE checkpoint
xgb_model_path = CHECKPOINT_DIR / f'xgboost_{model_name}_model.json'
xgb_model.save_model(str(xgb_model_path))
print(f"\n✓ XGBoost model saved to: {xgb_model_path}")

In [None]:
# XGBoost Parity plot: predicted vs. actual test targets
fig, ax = plt.subplots(figsize=(8, 8))

ax.scatter(y_test_np, xgb_preds, alpha=0.6, edgecolors='black', linewidth=0.5)

# Span the y = x line over actual AND predicted values (as the MoE parity plot
# does), so the reference line covers every plotted point
min_val = min(y_test_np.min(), xgb_preds.min())
max_val = max(y_test_np.max(), xgb_preds.max())
ax.plot([min_val, max_val], [min_val, max_val],
        'r--', linewidth=2, label='Perfect prediction')

ax.set_xlabel(f'Actual {target_col}', fontsize=12)
ax.set_ylabel(f'Predicted {target_col}', fontsize=12)
ax.set_title(f'XGBoost Parity Plot - {model_name}\nRMSE={xgb_rmse:.3f}, R²={xgb_r2:.3f}', fontsize=14)
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Model Comparison Summary

In [None]:
# Side-by-side test metrics for MoE + Net vs. MoE embeddings + XGBoost
comparison = pd.DataFrame({
    'Model': ['MOE + Net', 'MOE + XGBoost'],
    'RMSE': [rmse, xgb_rmse],
    'MAE': [mae, xgb_mae],
    'R²': [r2, xgb_r2],
})

banner = "=" * 60
print("\n" + banner)
print(f"FINAL COMPARISON - {model_name}")
print(banner)
print(comparison.to_string(index=False))
print(banner)

# One bar chart per metric, each bar annotated with its value
metrics = ['RMSE', 'MAE', 'R²']
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, metric in zip(axes, metrics):
    metric_values = comparison[metric]
    ax.bar(comparison['Model'], metric_values,
           color=['steelblue', 'coral'], edgecolor='black', alpha=0.7)
    ax.set_ylabel(metric, fontsize=12)
    ax.set_title(f'{metric} Comparison', fontsize=12)
    ax.grid(alpha=0.3, axis='y')

    for bar_position, bar_value in enumerate(metric_values):
        ax.text(bar_position, bar_value, f'{bar_value:.3f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

# Save comparison to checkpoint directory
comparison_path = CHECKPOINT_DIR / f'comparison_{model_name}.csv'
comparison.to_csv(comparison_path, index=False)
print(f"\n✓ Comparison saved to: {comparison_path}")

## 9. Make Predictions on New Molecules

In [None]:
# Example: Predict on new SMILES
new_smiles = [
    'CCO',  # Ethanol
    'CC(=O)O',  # Acetic acid
    'c1ccccc1',  # Benzene
]

# Normalize exactly as the training data was normalized
new_smiles_canon = [normalize_smiles(s) for s in new_smiles]

# normalize_smiles returns None for unparsable input; fail clearly here rather
# than passing None into the MoE experts
unparsable = [raw for raw, canon in zip(new_smiles, new_smiles_canon) if canon is None]
if unparsable:
    raise ValueError(f"Could not parse SMILES: {unparsable}")

# Predict using MOE+Net
moe_model.eval()
net.eval()
with torch.no_grad():
    new_embd, _ = moe_model(new_smiles_canon, verbose=False)
    # reshape(-1) keeps a 1-D array even when predicting a single molecule
    new_preds = net(new_embd).reshape(-1).cpu().numpy()

# Predict using XGBoost on the same MoE embeddings
xgb_new_preds = xgb_model.predict(new_embd.cpu().numpy())

# Display results
results_df = pd.DataFrame({
    'SMILES': new_smiles,
    'MOE+Net Prediction': new_preds,
    'XGBoost Prediction': xgb_new_preds
})

print(f"\nPredictions for {target_col}:")
print(results_df.to_string(index=False))