# MoL-MoE Production Template - Custom Regression Tasks

This notebook trains MoL-MoE (Mixture of Experts) models on custom CSV datasets.

**Features:**
- Auto-detects environment (Jupyter/Colab/RunPod)
- Fast dependency installation with `uv`
- Works with any CSV file (just configure SMILES and target columns)
- Includes device-aware training (no more GPU/CPU mismatch errors!)
- Trains both MoE+Net and MoE+XGBoost models

**Architecture:**
- 12 experts: 4x SMI-TED + 4x SELFIES-TED + 4x MHG-GNN
- k=4 experts activated per sample
- Suitable for regression tasks

## 1. Environment Setup

### Environment Detection & Path Configuration

In [None]:
# ============================================================
# SETUP CELL 1: Environment Detection & Path Configuration
# ============================================================
import os
import platform
import shutil
import subprocess
import sys
from pathlib import Path


def _detect_device_name():
    """Return 'GPU' if `nvidia-smi` is on PATH and exits with status 0, else 'CPU'.

    The probe is run without a shell, so it does not depend on POSIX-only
    redirection such as `> /dev/null` and behaves the same on Windows.
    """
    nvidia_smi = shutil.which('nvidia-smi')
    if nvidia_smi is None:
        return 'CPU'
    try:
        probe = subprocess.run(
            [nvidia_smi],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
    except OSError:
        # The binary exists but could not be executed; treat as "no GPU".
        return 'CPU'
    return 'GPU' if probe.returncode == 0 else 'CPU'


# Detect environment
IS_COLAB = 'google.colab' in sys.modules
IS_JUPYTER = 'ipykernel' in sys.modules
DEVICE_NAME = _detect_device_name()

print("="*60)
print("ENVIRONMENT DETECTION")
print("="*60)
print(f"Platform: {platform.system()}")
print(f"Python: {sys.version.split()[0]}")
print(f"Runtime: {'Colab' if IS_COLAB else 'Jupyter' if IS_JUPYTER else 'Unknown'}")
print(f"Device: {DEVICE_NAME}")
print("="*60)

# Define base paths using pathlib for cross-platform compatibility.
# All paths are derived from the current working directory, so the notebook
# must be started from its own folder for them to resolve.
NOTEBOOK_DIR = Path.cwd()
MOL_MOE_ROOT = NOTEBOOK_DIR.parent  # models/mol_moe/
EXPERTS_DIR = MOL_MOE_ROOT / "experts"
MOE_DIR = MOL_MOE_ROOT / "moe"
MATERIALS_ROOT = MOL_MOE_ROOT.parent.parent  # Up to materials/
DATA_DIR = MATERIALS_ROOT  # CSVs can be at materials/ level

print(f"\nPath Configuration:")
print(f"  Notebook directory: {NOTEBOOK_DIR}")
print(f"  MoE root: {MOL_MOE_ROOT}")
print(f"  Data directory: {DATA_DIR}")

# Verify critical paths exist. Explicit raises are used instead of `assert`
# so the check still runs when Python is started with -O.
for _label, _required_dir in (("Experts", EXPERTS_DIR), ("MoE", MOE_DIR)):
    if not _required_dir.exists():
        raise FileNotFoundError(f"{_label} directory not found: {_required_dir}")

print("\n‚úì All critical paths verified")

### UV Installation & Verification

In [None]:
# ============================================================
# SETUP CELL 2: UV Installation & Verification
# ============================================================
import subprocess
import shutil

def check_uv():
    """Report whether a working `uv` executable is available on PATH.

    Despite the historical wording, this function only checks; installation is
    done by the caller.

    Returns:
        True if `uv` is found on PATH and `uv --version` exits with status 0
        (the version string is printed in that case); False otherwise.
    """
    uv_path = shutil.which('uv')
    if uv_path is None:
        return False
    try:
        result = subprocess.run([uv_path, '--version'], capture_output=True, text=True)
    except OSError:
        # Found on PATH but not executable (e.g. wrong architecture).
        return False
    if result.returncode != 0:
        # A binary that cannot even report its version is not usable.
        return False
    print(f"‚úì uv found: {result.stdout.strip()}")
    return True

# Install uv through pip only when it is not already usable, then re-check.
if check_uv():
    print("uv already installed")
else:
    print("Installing uv...")
    pip_install_uv = [sys.executable, '-m', 'pip', 'install', '--quiet', 'uv']
    subprocess.run(pip_install_uv, check=True)

    # A second probe confirms the freshly installed binary is reachable.
    if not check_uv():
        raise RuntimeError("Failed to install uv")
    print("‚úì uv installed successfully")

### System Dependencies (Linux only)

In [None]:
# ============================================================
# SETUP CELL 3: Install System Dependencies (Linux only)
# ============================================================
import platform

# Best-effort install of the shared libraries listed below via apt-get;
# failures are reported as warnings and never abort the notebook.
if platform.system() != 'Linux':
    print(f"Platform: {platform.system()} - skipping Linux system dependencies")
else:
    print("Installing system dependencies for RDKit...")
    _apt_update_cmd = ['apt-get', 'update', '-qq']
    _apt_install_cmd = [
        'apt-get', 'install', '-y', '-qq',
        'libxrender1', 'libxext6', 'libsm6', 'libfontconfig1',
    ]
    try:
        # The index refresh is not checked; only the install itself must succeed.
        subprocess.run(_apt_update_cmd, capture_output=True, text=True)
        subprocess.run(_apt_install_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError:
        print("‚ö†Ô∏è  Warning: Could not install system packages (may need sudo)")
        print("   RDKit rendering may not work, but training will still function")
    except FileNotFoundError:
        print("‚ö†Ô∏è  Warning: apt-get not found (not a Debian/Ubuntu system)")
    else:
        print("‚úì System dependencies installed")

### Install Dependencies with UV

In [None]:
# ============================================================
# SETUP CELL 4: Install Dependencies with UV
# ============================================================


def _uv_pip_install(*args):
    """Run `uv pip install` against the current interpreter.

    Args:
        *args: extra command-line arguments (index options and package specs).

    Raises:
        subprocess.CalledProcessError: if uv exits with a non-zero status.
            Output is captured as text so the caller can display `stderr`.
    """
    subprocess.run(
        ['uv', 'pip', 'install', '--python', sys.executable, *args],
        check=True,
        capture_output=True,
        text=True,
    )


print("Installing dependencies with uv (first run: ~3-5 min, cached: ~30 sec)...")
print("Configuration:")
print("  - Python: 3.10+")
print("  - PyTorch: 2.2.0 with CUDA 11.8")
print()

try:
    # Step 1: Install PyTorch with CUDA 11.8
    print("[1/3] Installing PyTorch with CUDA 11.8...")
    _uv_pip_install(
        '--index-url', 'https://download.pytorch.org/whl/cu118',
        'torch==2.2.0',
        'torchvision==0.17.0',
        'torchaudio==2.2.0',
    )
    print("      ‚úì PyTorch 2.2.0 with CUDA 11.8 installed")

    # Step 2: Install torch-scatter from the wheel index matching torch 2.2.0+cu118
    print("[2/3] Installing torch-scatter...")
    _uv_pip_install(
        '--find-links', 'https://data.pyg.org/whl/torch-2.2.0+cu118.html',
        'torch-scatter',
    )
    print("      ‚úì torch-scatter installed")

    # Step 3: Install remaining dependencies
    print("[3/3] Installing remaining dependencies...")
    remaining_deps = [
        'torch-geometric>=2.3.1',
        'matplotlib==3.9.2',
        'numpy==1.26.4',  # Pin to 1.26.4 to avoid numpy 2.x breaking changes
        'pandas>=1.5.3',
        'scikit-learn>=1.5.0',
        'rdkit>=2024.3.5',
        'datasets>=2.13.1',
        'huggingface-hub',
        'transformers==4.44.2',  # Pinned version compatible with BART models
        'selfies>=2.1.0',
        'tqdm>=4.66.4',
        'xgboost==2.1.3',  # Updated version compatible with numpy 1.26.x
        'seaborn',
    ]

    _uv_pip_install(*remaining_deps)
    print("      ‚úì All dependencies installed successfully")

except subprocess.CalledProcessError as e:
    print(f"\n‚ùå Installation failed: {e}")
    # uv's output was captured above; show it so the actual cause is visible.
    if e.stderr:
        print(e.stderr.strip())
    raise

# Verify installations
import torch
print(f"\nVerification:")
print(f"  PyTorch version: {torch.__version__}")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")

print("\n" + "="*60)
print("‚úì Environment setup complete!")
print("="*60)


### Configure Module Import Paths

In [None]:
# ============================================================
# SETUP CELL 5: Configure Module Import Paths
# ============================================================

# Add project directories to the front of the Python path.
# Any existing entry is removed first, so re-running this cell neither
# duplicates entries nor changes the resulting priority order
# (MOE_DIR, then EXPERTS_DIR, then MOL_MOE_ROOT).
for _project_dir in (MOL_MOE_ROOT, EXPERTS_DIR, MOE_DIR):
    _entry = str(_project_dir)
    if _entry in sys.path:
        sys.path.remove(_entry)
    sys.path.insert(0, _entry)

print("Module search paths configured:")
for i, path in enumerate(sys.path[:6]):
    print(f"  {i}: {path}")

# Verify imports work, remembering failures so the final banner is truthful.
_failed_imports = []

try:
    from moe import MoE
    print("\n‚úì MoE module importable")
except ImportError as e:
    _failed_imports.append("moe.MoE")
    print(f"\n‚úó MoE import failed: {e}")

try:
    from models import Net
    print("‚úì Net model importable")
except ImportError as e:
    _failed_imports.append("models.Net")
    print(f"‚úó Net import failed: {e}")

if _failed_imports:
    # Do not claim success: later cells import these names unconditionally.
    print(f"\nWARNING: module paths were set, but these imports failed: {', '.join(_failed_imports)}")
    print("\n" + "="*60)
    print("SETUP INCOMPLETE - fix the import errors above before training")
    print("="*60)
else:
    print("\n‚úì Module paths configured successfully")
    print("\n" + "="*60)
    print("SETUP COMPLETE - Ready to proceed with training!")
    print("="*60)

## 2. Imports & Helper Functions

In [None]:
# System
import warnings
warnings.filterwarnings("ignore")

# Deep learning
import torch
import torch.nn.functional as F
from torch import nn
from moe import MoE, train
from models import Net

# Machine learning
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Data
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Chemistry
from rdkit import Chem
from rdkit.Chem import Descriptors

# Try to enable PandasTools rendering. This is optional: any ordinary failure
# is reported and ignored. `except Exception` (rather than a bare `except`)
# lets KeyboardInterrupt and SystemExit propagate.
try:
    from rdkit.Chem import PandasTools
    PandasTools.RenderImagesInAllDataFrames(True)
    print("‚úì RDKit rendering enabled")
except Exception:
    print("‚ö†Ô∏è  RDKit rendering disabled (missing system libraries)")
    print("   Training will work normally")

def normalize_smiles(smi, canonical=True, isomeric=False):
    """Convert a SMILES string to RDKit's normalized form.

    Args:
        smi: input SMILES string.
        canonical: forwarded to `Chem.MolToSmiles(canonical=...)`.
        isomeric: forwarded to `Chem.MolToSmiles(isomericSmiles=...)`.

    Returns:
        The normalized SMILES string, or None when the input cannot be parsed
        or converted (including non-string inputs such as NaN).
    """
    try:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # RDKit signals an unparsable SMILES by returning None.
            return None
        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)
    except Exception:
        # Best-effort by design: callers drop rows whose result is None.
        return None

# Seed NumPy and PyTorch with the same fixed value, then select the compute
# device: CUDA when available, otherwise CPU.
np.random.seed(42)
torch.manual_seed(42)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

## 3. Data Configuration

**⚠️ EDIT THIS SECTION** - Configure your custom CSV file:

In [None]:
# ============================================
# USER CONFIGURATION - EDIT THIS SECTION
# ============================================

# Path to your CSV file
DATA_FILE = DATA_DIR / 'train_Caco2_Permeability_Papp_AB.csv'

# Column names in your CSV
SMILES_COLUMN = 'SMILES'
TARGET_COLUMN = 'Caco-2 Permeability Papp A>B'

# Task configuration
# NOTE: the rest of this notebook is regression-only (MSE loss, XGBRegressor);
# TASK_TYPE is only echoed below and does not switch any behavior.
TASK_TYPE = 'regression'
MODEL_NAME = 'Caco2_Papp_AB'  # Used for checkpoint naming

# ============================================

# Verify data file exists. An explicit raise is used instead of `assert` so
# the check still runs when Python is started with -O.
if not DATA_FILE.exists():
    raise FileNotFoundError(f"Data file not found: {DATA_FILE}")

print(f"Selected configuration:")
print(f"  Data file: {DATA_FILE}")
print(f"  SMILES column: {SMILES_COLUMN}")
print(f"  Target column: {TARGET_COLUMN}")
print(f"  Task type: {TASK_TYPE}")
print(f"  Model name: {MODEL_NAME}")

## 4. Hyperparameters

**⚠️ EDIT AS NEEDED** - Adjust training parameters:

In [None]:
# ============================================
# HYPERPARAMETERS - ADJUST AS NEEDED
# ============================================
import math

# Model architecture
input_size = 768          # Embedding dimension (fixed)
output_size = 2048        # Output dimension
num_experts = 12          # Total experts
k = 4                     # Experts activated per sample

# Training settings
batch_size = 32           # Reduce if OOM (e.g., 16)
learning_rate = 1e-4      # Learning rate
epochs = 150              # Training epochs
dropout = 0.2             # Dropout rate

# Output
output_dim = 1            # Single target regression

# Data split ratios (must sum to 1)
train_ratio = 0.70
valid_ratio = 0.15
test_ratio = 0.15

# ============================================

# The split cell derives the test share from train_ratio and valid_ratio, so
# ratios that do not sum to 1 would silently yield a test set of a different
# size than test_ratio states. Fail early instead.
_ratio_total = train_ratio + valid_ratio + test_ratio
if not math.isclose(_ratio_total, 1.0, abs_tol=1e-9):
    raise ValueError(
        f"train_ratio + valid_ratio + test_ratio must equal 1.0, got {_ratio_total}"
    )

print(f"Configuration:")
print(f"  Batch size: {batch_size}")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {epochs}")
print(f"  Experts: {num_experts}, activating {k} per sample")
print(f"  Device: {DEVICE}")

## 5. Load Foundation Models

In [None]:
# Load the SELFIES-TED expert: construct the project wrapper, then call its
# load() method. NOTE(review): load() presumably fetches/initialises the
# pretrained weights - confirm in experts/selfies_ted/load.py.
from experts.selfies_ted.load import SELFIES

print("Loading SELFIES-TED...")
model_selfies = SELFIES()
model_selfies.load()
print("‚úì SELFIES-TED loaded")

In [None]:
# Load the MHG-GNN expert through the project's load() factory; the returned
# object is kept as `mhg_gnn` and used later as one of the MoE experts.
from experts.mhg_model.load import load

print("Loading MHG-GNN...")
mhg_gnn = load()
print("‚úì MHG-GNN loaded")

In [None]:
# Load the SMI-TED expert. MolTranBertTokenizer is imported here because the
# MoE initialisation cell builds the tokenizer from it.
from experts.smi_ted_light.load import load_smi_ted, MolTranBertTokenizer

print("Loading SMI-TED...")
smi_ted = load_smi_ted()
print("‚úì SMI-TED loaded")

## 6. Load and Prepare Data

In [None]:
# Load data
df = pd.read_csv(DATA_FILE)

# Fail early with a helpful message when the configured columns are missing,
# instead of a bare KeyError in a later cell.
_missing_columns = [col for col in (SMILES_COLUMN, TARGET_COLUMN) if col not in df.columns]
if _missing_columns:
    raise KeyError(
        f"Column(s) {_missing_columns} not found in {DATA_FILE}. "
        f"Available columns: {list(df.columns)}"
    )

print(f"Original dataset shape: {df.shape}")
print(f"\nFirst few rows:")
# Last expression of the cell, so the notebook displays the preview.
df.head()

In [None]:
# Normalize SMILES: add a 'canon_smiles' column holding the normalized form
# (None where normalize_smiles could not process the input).
print("Normalizing SMILES...")
df['canon_smiles'] = df[SMILES_COLUMN].map(normalize_smiles)

# Remove rows with an invalid SMILES or a missing target value
original_count = len(df)
_required_columns = ['canon_smiles', TARGET_COLUMN]
df = df.dropna(subset=_required_columns)
_removed = original_count - len(df)
print(f"Removed {_removed} invalid entries")
print(f"Final dataset shape: {df.shape}")

# Show target statistics
_target_summary = df[TARGET_COLUMN].describe()
print(f"\n{TARGET_COLUMN} statistics:")
print(_target_summary)

In [None]:
# Visualize target distribution: histogram (left) and box plot (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
_hist_ax, _box_ax = axes
_target_values = df[TARGET_COLUMN]

_hist_ax.hist(_target_values, bins=50, edgecolor='black', alpha=0.7)
_hist_ax.set_xlabel(TARGET_COLUMN)
_hist_ax.set_ylabel('Frequency')
_hist_ax.set_title(f'Distribution of {TARGET_COLUMN}')

_box_ax.boxplot(_target_values)
_box_ax.set_ylabel(TARGET_COLUMN)
_box_ax.set_title(f'Box Plot of {TARGET_COLUMN}')

# Same light grid on both panels
for _panel in (_hist_ax, _box_ax):
    _panel.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Split data in two stages: first hold out everything that is not training
# data, then divide that hold-out into validation and test sets.
train_df, temp_df = train_test_split(df, test_size=1 - train_ratio, random_state=42)
valid_size = valid_ratio / (valid_ratio + test_ratio)
valid_df, test_df = train_test_split(temp_df, test_size=1 - valid_size, random_state=42)

for _split_name, _split_df in (
    ("Training", train_df),
    ("Validation", valid_df),
    ("Test", test_df),
):
    print(f"{_split_name} set: {len(_split_df)} samples")

# Prepare data
smiles_col = 'canon_smiles'


def _to_xy(frame):
    """Return (list of SMILES, float32 target tensor) for one split."""
    features = frame[smiles_col].to_list()
    targets = torch.tensor(frame[TARGET_COLUMN].values, dtype=torch.float32)
    return features, targets


X_train, y_train = _to_xy(train_df)
X_valid, y_valid = _to_xy(valid_df)
X_test, y_test = _to_xy(test_df)

## 7. Pre-Training Validation

In [None]:
import shutil

print("Running pre-training validation checks...\n")

# 1. GPU Check (informational only; training continues on CPU)
if not torch.cuda.is_available():
    print("‚ö†Ô∏è  No GPU available - training on CPU (slower)")
else:
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  Memory: {gpu_mem_gb:.1f} GB")

# 2. Data Check
assert DATA_FILE.exists(), f"‚ùå Data file not found: {DATA_FILE}"
print(f"‚úì Data file valid: {len(df)} samples")

# 3. Disk Space Check: require more than 10 GB free in the working directory
disk_usage = shutil.disk_usage('.')
free_gb = disk_usage.free / 1e9
assert free_gb > 10, f"‚ùå Low disk space: {free_gb:.1f} GB"
print(f"‚úì Disk space: {free_gb:.1f} GB free")

# 4. Create checkpoint directory next to the notebook
CHECKPOINT_DIR = NOTEBOOK_DIR / 'checkpoints'
CHECKPOINT_DIR.mkdir(exist_ok=True)
print(f"‚úì Checkpoint directory: {CHECKPOINT_DIR}")

_rule = "="*60
print("\n" + _rule)
print("All validation checks passed! Ready to train.")
print(_rule)

## 8. Initialize MoE Model

In [None]:
# Build the MoE model and its prediction head.
# Statement order matters here: experts are moved to DEVICE before MoE is
# constructed (see the IMPORTANT note below).

# Define experts (4 per modality)
# The same three loaded objects are each listed four times, so the 12 slots
# reference only three distinct expert instances.
models = [
    smi_ted, smi_ted, smi_ted, smi_ted,                      # SMI-TED
    model_selfies, model_selfies, model_selfies, model_selfies,  # SELFIES
    mhg_gnn, mhg_gnn, mhg_gnn, mhg_gnn                        # MHG-GNN
]

# Initialize tokenizer from the vocabulary file shipped with smi_ted_light
vocab_path = EXPERTS_DIR / 'smi_ted_light' / 'bert_vocab_curated.txt'
assert vocab_path.exists(), f"Vocab file not found: {vocab_path}"
tokenizer = MolTranBertTokenizer(str(vocab_path))

# IMPORTANT: Move all expert models to DEVICE BEFORE creating MoE
print("Moving expert models to device...")
smi_ted.to(DEVICE)
model_selfies.to(DEVICE)
mhg_gnn.to(DEVICE)
print(f"  SMI-TED device: {smi_ted.device}")
print(f"  SELFIES device: {model_selfies.device}")
# MHG-GNN's device is read from the first parameter of its inner `.model`.
print(f"  MHG-GNN device: {next(mhg_gnn.model.parameters()).device}")

# Initialize MoE
print("\nInitializing MoE model...")
moe_model = MoE(
    input_size=input_size,
    output_size=output_size,
    num_experts=num_experts,
    models=models,
    tokenizer=tokenizer,
    tok_emb=smi_ted.encoder.tok_emb,
    k=k,
    noisy_gating=True,
    verbose=False
).to(DEVICE)  # This also sets target device on all experts

# Initialize predictor network
# The head consumes MoE outputs (width output_size) and emits output_dim values.
net = Net(smiles_embed_dim=output_size, dropout=dropout, output_dim=output_dim)
# Reuses SMI-TED's weight-initialisation routine for the new head.
net.apply(smi_ted._init_weights)
net = net.to(DEVICE)

print("\n" + "="*50)
print("‚úì All models initialized and on correct device")
print("="*50)
print(f"  MoE device: {next(moe_model.parameters()).device}")
print(f"  Net device: {next(net.parameters()).device}")



## 9. Training

In [None]:
from tqdm import tqdm

# Mean Squared Error loss for regression
loss_fn = nn.MSELoss()

# One optimizer over the parameters of both the MoE and the prediction head
params = [*moe_model.parameters(), *net.parameters()]
optim = torch.optim.AdamW(params, lr=learning_rate, weight_decay=1e-5)

# Halve the learning rate after 10 epochs without improvement in the metric
# passed to scheduler.step()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='min',
    factor=0.5,
    patience=10,
    verbose=True,
)

# Training batches are (SMILES, target) pairs, reshuffled every epoch
_train_pairs = list(zip(X_train, y_train))
train_loader = torch.utils.data.DataLoader(
    _train_pairs,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # Avoid multiprocessing issues
)

print(f"Training batches per epoch: {len(train_loader)}")

In [None]:
import itertools
import sys
import torch
import pandas as pd
import numpy as np


def patched_dispatch(self, inp):
    """Replacement for SparseDispatcher.dispatch that tolerates GPU index tensors.

    Args:
        self: the dispatcher instance; reads `self._batch_index` (row indices
            into `inp`, possibly a torch tensor on any device) and
            `self._part_sizes` (number of rows routed to each expert).
        inp: sequence of inputs to route.

    Returns:
        A list with one entry per element of `self._part_sizes`; each entry is
        the list of inputs routed to that expert, in `_batch_index` order.
    """
    inp = pd.Series(inp)

    # pandas positional indexing needs host-side indices, not a CUDA tensor.
    indices = self._batch_index
    if isinstance(indices, torch.Tensor):
        indices = indices.cpu().numpy()

    inp_exp = inp.iloc[indices]

    # Split offsets are the running totals of the part sizes, without the last
    # one (computed in a single pass instead of re-summing every prefix).
    split_points = list(itertools.accumulate(self._part_sizes))[:-1]
    return [list(x) for x in np.split(inp_exp.to_numpy(), split_points, axis=0)]


# Retrieve the module containing the bug
# The module is imported as 'moe', not 'moe.moe' based on the project structure
moe_module = sys.modules.get("moe")

if moe_module and hasattr(moe_module, 'SparseDispatcher'):
    print("Patching SparseDispatcher.dispatch to fix GPU/CPU device mismatch...")

    # Apply the patch to the class
    moe_module.SparseDispatcher.dispatch = patched_dispatch
    print("‚úì Patch applied successfully. You can now run the training loop.")
else:
    print("‚ùå Could not find 'moe' module or 'SparseDispatcher' class. Make sure you have run the setup cells and 'moe' is imported.")

In [None]:
import inspect
import sys
import numpy as np
import textwrap # Import textwrap for dedenting
from experts.selfies_ted.load import SELFIES

print("Patching SELFIES.encode to fix 'Column' object error...")

# The error comes from calling .copy() on a datasets Column object;
# the fix materialises it as a list before handing it to np.asarray.
buggy_line = 'emb = np.asarray(embedding["embedding"].copy())'
fixed_line = 'emb = np.asarray(list(embedding["embedding"]))'

# 1. Get the source code of the failing function.
# getsource raises OSError when no source file backs the function, which is
# the case after this cell has already replaced it via exec() - so a re-run
# must not crash here.
try:
    source = inspect.getsource(SELFIES.encode)
except (OSError, TypeError):
    source = None

if source is None:
    print("Could not read the source of SELFIES.encode (it may already be patched); leaving it unchanged.")
else:
    # 2. Dedent the source code to remove class-level indentation
    dedented_source = textwrap.dedent(source)

    # 3. Replace the buggy line with the fixed line in the dedented source
    if buggy_line in dedented_source:
        new_source = dedented_source.replace(buggy_line, fixed_line)

        # 4. Execute the new function code in the original module's context
        # This ensures it has access to necessary imports like 'Dataset' inside the module.
        # NOTE: exec() runs text taken from the project's own module, not external input.
        module_globals = sys.modules['experts.selfies_ted.load'].__dict__
        local_vars = {}
        exec(new_source, module_globals, local_vars)

        # 5. Replace the method on the class (the function is now in local_vars as 'encode')
        SELFIES.encode = local_vars['encode']
        print("‚úì Patch applied successfully.")
    else:
        print("‚ö†Ô∏è Could not locate the exact buggy line. The file might differ from expectations.")

print("\nYou can now re-run the training loop cell.")

In [None]:
# Custom training loop with validation.
# Each epoch: one pass over train_loader, one full-batch validation pass,
# a scheduler step on the validation loss, and a checkpoint whenever the
# validation loss improves.
train_losses = []
valid_losses = []
best_valid_loss = float('inf')

print("Starting training...\n")

for epoch in range(epochs):
    # Training
    moe_model.train()
    net.train()
    epoch_loss = 0

    for (x, y) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        y = y.to(DEVICE)  # CRITICAL FIX: Move y to device

        optim.zero_grad()

        # Forward pass
        embd, aux_loss = moe_model(x)
        # squeeze(-1) drops only the trailing output dimension, so a final
        # batch of size 1 stays shape (1,) and still matches y
        # (a bare squeeze() would collapse it to a 0-d tensor).
        y_hat = net(embd).squeeze(-1)

        # Calculate loss; the MoE auxiliary loss is added for optimisation
        # only and is not included in the reported training loss.
        loss = loss_fn(y_hat, y)
        total_loss = loss + aux_loss

        # Backward pass
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        optim.step()

        epoch_loss += loss.item()

    avg_train_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation (whole validation set in a single forward pass)
    moe_model.eval()
    net.eval()
    with torch.no_grad():
        valid_embd, _ = moe_model(X_valid, verbose=False)
        valid_preds = net(valid_embd).squeeze(-1)
        # y_valid was never moved to DEVICE, so compare on the CPU.
        valid_loss = loss_fn(valid_preds.cpu(), y_valid).item()

    valid_losses.append(valid_loss)
    scheduler.step(valid_loss)

    # Save best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        checkpoint_path = CHECKPOINT_DIR / f'best_{MODEL_NAME}_moe_model.pt'
        torch.save({
            'epoch': epoch,
            'moe_state_dict': moe_model.state_dict(),
            'net_state_dict': net.state_dict(),
            'optimizer_state_dict': optim.state_dict(),
            'valid_loss': valid_loss,
        }, checkpoint_path)

    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"\nEpoch {epoch+1}/{epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {valid_loss:.4f}")
        print(f"  Best Valid Loss: {best_valid_loss:.4f}\n")

print("\n‚úì Training completed!")
print(f"Best validation loss: {best_valid_loss:.4f}")
print(f"Model saved to: {CHECKPOINT_DIR / f'best_{MODEL_NAME}_moe_model.pt'}")

In [None]:
# Plot the per-epoch training and validation losses on one figure
plt.figure(figsize=(10, 5))
for _history, _legend_label in (
    (train_losses, 'Training Loss'),
    (valid_losses, 'Validation Loss'),
):
    plt.plot(_history, label=_legend_label, alpha=0.8)

plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title(f'Training History - {MODEL_NAME}')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

## 10. Evaluate on Test Set

In [None]:
# Load best model
checkpoint_path = CHECKPOINT_DIR / f'best_{MODEL_NAME}_moe_model.pt'
# map_location remaps the saved tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
moe_model.load_state_dict(checkpoint['moe_state_dict'])
net.load_state_dict(checkpoint['net_state_dict'])
# 'epoch' is stored zero-based by the training loop; report it one-based.
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Checkpoint: {checkpoint_path}")

In [None]:
# Evaluate on test set: switch both networks to eval mode, predict the whole
# test split in one forward pass, then compute RMSE / MAE / R2 on the CPU.
for _module in (moe_model, net):
    _module.eval()

with torch.no_grad():
    test_embd, _ = moe_model(X_test, verbose=False)
    test_preds = net(test_embd).squeeze()
    test_preds_np = test_preds.cpu().numpy()
    y_test_np = y_test.numpy()

# Calculate metrics
_test_mse = mean_squared_error(y_test_np, test_preds_np)
rmse = np.sqrt(_test_mse)
mae = mean_absolute_error(y_test_np, test_preds_np)
r2 = r2_score(y_test_np, test_preds_np)

_rule = "="*50
for _line in (
    "\n" + _rule,
    f"TEST SET RESULTS - {MODEL_NAME}",
    _rule,
    f"RMSE: {rmse:.4f}",
    f"MAE:  {mae:.4f}",
    f"R¬≤:   {r2:.4f}",
    _rule,
):
    print(_line)

In [None]:
# Parity plot: predicted vs. actual test values with the y = x reference line
fig, ax = plt.subplots(figsize=(8, 8))

ax.scatter(y_test_np, test_preds_np, alpha=0.6, edgecolors='black', linewidth=0.5)

# The diagonal spans the combined range of actual and predicted values
min_val = min(y_test_np.min(), test_preds_np.min())
max_val = max(y_test_np.max(), test_preds_np.max())
_diagonal = [min_val, max_val]
ax.plot(_diagonal, _diagonal, 'r--', linewidth=2, label='Perfect prediction')

_title = f'Parity Plot - {MODEL_NAME}\nRMSE={rmse:.3f}, R¬≤={r2:.3f}'
ax.set_title(_title, fontsize=14)
ax.set_xlabel(f'Actual {TARGET_COLUMN}', fontsize=12)
ax.set_ylabel(f'Predicted {TARGET_COLUMN}', fontsize=12)
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 11. Train XGBoost on MoE Embeddings

In [None]:
# Extract embeddings for XGBoost: each split is embedded in a single forward
# pass of the MoE (eval mode, no gradients) and converted to a NumPy array.
print("Extracting embeddings...")
moe_model.eval()

with torch.no_grad():
    xgb_train, xgb_valid, xgb_test = (
        moe_model(_split, verbose=False)[0]
        for _split in (X_train, X_valid, X_test)
    )

xgb_train, xgb_valid, xgb_test = (
    _embedding.cpu().numpy() for _embedding in (xgb_train, xgb_valid, xgb_test)
)

y_train_np, y_valid_np = y_train.numpy(), y_valid.numpy()

for _split_label, _matrix in (
    ("Train", xgb_train),
    ("Valid", xgb_valid),
    ("Test", xgb_test),
):
    print(f"{_split_label} embeddings shape: {_matrix.shape}")

In [None]:
# Train XGBoost on the MoE embeddings, using the validation split for early
# stopping (50 rounds without RMSE improvement).
print("Training XGBoost...")
_xgb_settings = {
    'n_estimators': 2000,
    'learning_rate': 0.05,
    'max_depth': 8,
    'subsample': 0.8,
    'colsample_bytree': 0.8,
    'random_state': 42,
    'early_stopping_rounds': 50,
    'eval_metric': 'rmse',
}
xgb_model = XGBRegressor(**_xgb_settings)

_validation_pair = (xgb_valid, y_valid_np)
xgb_model.fit(xgb_train, y_train_np, eval_set=[_validation_pair], verbose=100)

print("\n‚úì XGBoost training completed")

In [None]:
# Evaluate XGBoost on the test embeddings against y_test_np (created by the
# MoE+Net evaluation cell), then persist the fitted model as JSON.
xgb_preds = xgb_model.predict(xgb_test)

_xgb_mse = mean_squared_error(y_test_np, xgb_preds)
xgb_rmse = np.sqrt(_xgb_mse)
xgb_mae = mean_absolute_error(y_test_np, xgb_preds)
xgb_r2 = r2_score(y_test_np, xgb_preds)

_rule = "="*50
for _line in (
    "\n" + _rule,
    f"XGBoost TEST SET RESULTS - {MODEL_NAME}",
    _rule,
    f"RMSE: {xgb_rmse:.4f}",
    f"MAE:  {xgb_mae:.4f}",
    f"R¬≤:   {xgb_r2:.4f}",
    _rule,
):
    print(_line)

# Save XGBoost model
xgb_model_path = CHECKPOINT_DIR / f'xgboost_{MODEL_NAME}_model.json'
xgb_model.save_model(str(xgb_model_path))
print(f"\n‚úì XGBoost model saved to: {xgb_model_path}")

## 12. Model Comparison

In [None]:
# Compare both approaches: tabulate the test metrics, draw one bar chart per
# metric, and write the table to the checkpoint directory.
comparison = pd.DataFrame({
    'Model': ['MOE + Net', 'MOE + XGBoost'],
    'RMSE': [rmse, xgb_rmse],
    'MAE': [mae, xgb_mae],
    'R¬≤': [r2, xgb_r2]
})

_rule = "="*60
print("\n" + _rule)
print(f"FINAL COMPARISON - {MODEL_NAME}")
print(_rule)
print(comparison.to_string(index=False))
print(_rule)

# Visualize comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

metrics = ['RMSE', 'MAE', 'R¬≤']
for _panel, _metric_name in zip(axes, metrics):
    _metric_values = comparison[_metric_name]
    _panel.bar(comparison['Model'], _metric_values,
               color=['steelblue', 'coral'], edgecolor='black', alpha=0.7)
    _panel.set_ylabel(_metric_name, fontsize=12)
    _panel.set_title(f'{_metric_name} Comparison', fontsize=12)
    _panel.grid(alpha=0.3, axis='y')

    # Annotate each bar with its value
    for _bar_pos, _bar_value in enumerate(_metric_values):
        _panel.text(_bar_pos, _bar_value, f'{_bar_value:.3f}',
                    ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

# Save comparison
comparison_path = CHECKPOINT_DIR / f'comparison_{MODEL_NAME}.csv'
comparison.to_csv(comparison_path, index=False)
print(f"\n‚úì Comparison saved to: {comparison_path}")

## 13. Make Predictions on New Molecules

In [None]:
# Example: Predict on new SMILES
new_smiles = [
    'CCO',        # Ethanol
    'CC(=O)O',    # Acetic acid
    'c1ccccc1',   # Benzene
]

# Normalize
new_smiles_canon = [normalize_smiles(s) for s in new_smiles]

# normalize_smiles returns None for inputs it cannot process; drop those so
# None is never passed to the model, and keep inputs and results aligned.
_valid_pairs = [
    (raw, canon)
    for raw, canon in zip(new_smiles, new_smiles_canon)
    if canon is not None
]
if not _valid_pairs:
    raise ValueError("None of the provided SMILES strings could be normalized")
if len(_valid_pairs) < len(new_smiles):
    _skipped = [raw for raw, canon in zip(new_smiles, new_smiles_canon) if canon is None]
    print(f"Skipping {len(_skipped)} invalid SMILES: {_skipped}")
_valid_smiles = [raw for raw, _canon in _valid_pairs]
_valid_canon = [canon for _raw, canon in _valid_pairs]

# Predict using MOE+Net
moe_model.eval()
net.eval()
with torch.no_grad():
    new_embd, _ = moe_model(_valid_canon, verbose=False)
    # squeeze(-1) keeps a 1-D result even when only one molecule is given.
    new_preds = net(new_embd).squeeze(-1).cpu().numpy()

# Predict using XGBoost
xgb_new_preds = xgb_model.predict(new_embd.cpu().numpy())

# Display results
results_df = pd.DataFrame({
    'SMILES': _valid_smiles,
    'MOE+Net Prediction': new_preds,
    'XGBoost Prediction': xgb_new_preds
})

print(f"\nPredictions for {TARGET_COLUMN}:")
print(results_df.to_string(index=False))

## 🎉 Training Complete!

**What you've accomplished:**
- ✅ Trained MoL-MoE model on custom dataset
- ✅ Trained XGBoost on MoE embeddings
- ✅ Evaluated both models on test set
- ✅ Saved checkpoints and comparisons

**Next steps:**
1. Review model performance metrics
2. Try different hyperparameters if needed
3. Use trained models for predictions on new data
4. Compare with other baseline models

**Files saved:**
- `checkpoints/best_{MODEL_NAME}_moe_model.pt` - Best MoE model
- `checkpoints/xgboost_{MODEL_NAME}_model.json` - XGBoost model
- `checkpoints/comparison_{MODEL_NAME}.csv` - Performance comparison