# Fine-tuning MIST Encoder Models

<a target="_blank" href="https://colab.research.google.com/github/BattModels/mist-demo/blob/main/tutorials/run_finetuning.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

The core advantage of a foundation model is that it can be adapted to a wide range of downstream tasks given a small number of labelled examples.
We have demonstrated the MIST models' efficacy as scientific foundation models by fine-tuning variants of MIST to predict over 400 properties --- including quantum mechanical, thermodynamic, biochemical, and psychophysical properties --- from a molecule’s SMILES representation.
The encoders are fine-tuned on single molecule property prediction (classification and regression) tasks by attaching a small two-layer MLP. 

This tutorial demonstrates how to fine-tune MIST encoder models for downstream molecular property prediction tasks.
As an example, we will fine-tune a MIST encoder to predict LUMO (Lowest Unoccupied Molecular Orbital) energies from the QM9 dataset.

In [None]:
# Install dependencies
# ! pip install git+https://github.com/BattModels/mist-demo.git -q

In [None]:
# Import dependencies
import torch
from rdkit import Chem
from smirk import SmirkTokenizerFast
from datasets import load_dataset
from transformers import (
    AutoModel,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from peft import LoraConfig, get_peft_model
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from optuna.visualization import plot_optimization_history

from mist_demo.finetuning.optimize_hyperparams import tune_hyperparameters
from mist_demo.finetuning.regression_model import RegressionModel

# Pick the compute device: prefer CUDA, then Apple's MPS backend, else CPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

## Data Preparation 

MIST models were pretrained on *kekulized SMILES*, a form of SMILES that writes aromatic rings with explicit alternating single and double bonds instead of lowercase aromatic atoms (for example, benzene `c1ccccc1` becomes `C1=CC=CC=C1`). We need to kekulize our input SMILES strings before tokenization.

In [None]:
def kekulize_smiles(smiles):
    """Return the kekulized form of a SMILES string.

    The string is parsed with RDKit, the molecule is kekulized in place,
    and it is written back out with ``kekuleSmiles=True``.

    Args:
        smiles: SMILES string to convert.

    Returns:
        The kekulized SMILES string produced by RDKit.

    Raises:
        ValueError: If RDKit cannot parse ``smiles`` (``MolFromSmiles``
            returned ``None``).
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is not None:
        Chem.Kekulize(parsed)
        return Chem.MolToSmiles(parsed, kekuleSmiles=True)
    raise ValueError(f"Invalid SMILES: {smiles}")


# Sanity check: benzene written in aromatic (lowercase) notation
test_smiles = "c1ccccc1"
kekulized = kekulize_smiles(test_smiles)
print(f"Original: {test_smiles}", f"Kekulized: {kekulized}", sep="\n")

We'll create a function to tokenize SMILES strings with proper kekulization:

In [None]:
def tokenize_function(examples, tokenizer):
    """Kekulize and tokenize a batch of SMILES strings.

    Args:
        examples: Mapping with a ``"smiles"`` entry holding an iterable of
            SMILES strings.
        tokenizer: Callable tokenizer applied to the list of kekulized
            strings.

    Returns:
        Whatever ``tokenizer`` returns for the kekulized batch, padded to a
        fixed length of 512 tokens.

    Raises:
        ValueError: Propagated from ``kekulize_smiles`` for an unparsable
            SMILES string.
    """
    smiles_batch = list(map(kekulize_smiles, examples["smiles"]))
    return tokenizer(smiles_batch, padding="max_length", max_length=512)

In [None]:
# Download the QM9 CSV and load it as a single "train" split
qm9_url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv"
dataset = load_dataset("csv", data_files={"train": qm9_url})["train"]

print(f"Total examples: {len(dataset)}")
print(f"Features: {dataset.features}")
print("\nFirst example:")
print(dataset[0])

# Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible
dataset = dataset.train_test_split(test_size=0.2, seed=42)

for split_label, split_key in (("Train", "train"), ("Test", "test")):
    print(f"{split_label} size: {len(dataset[split_key])}")

In [None]:
# Tokenize dataset
# MIST models were trained using the Smirk tokenizer on kekulized SMILES,
# so each SMILES string is kekulized before it is tokenized.

tokenizer = SmirkTokenizerFast()


def _kekulize_and_tokenize(smiles):
    """Kekulize one SMILES string and tokenize the result.

    No padding is applied here; sequences keep their natural length and
    the DataCollatorWithPadding used for training pads each batch.
    """
    return tokenizer(kekulize_smiles(smiles))


tokenized_dataset = dataset.map(
    _kekulize_and_tokenize,
    input_columns=["smiles"],  # pass only the SMILES string to the function
    desc="Tokenizing",
)

# Rename target column to 'labels' (expected by Trainer)
tokenized_dataset = tokenized_dataset.rename_column("lumo", "labels")

# Set format for PyTorch; only these columns are returned as tensors
tokenized_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

print("Tokenization complete!")
tokenized_dataset["train"].to_pandas().head()

# Load the pretrained MIST encoder 

We'll load a pretrained MIST encoder from HuggingFace Hub. You can choose from two different model sizes:
- `mist-models/mist-1.8B-dh61satt` (1.8B parameters)
- `mist-models/mist-28M-ti624ev1` (28M parameters)

In [None]:
# The 28M checkpoint is used here: it is much faster and lighter on memory
# than the larger model.
model_path = "mist-models/mist-28M-ti624ev1"

print(f"Loading model: {model_path}")
# trust_remote_code=True lets transformers run the model code shipped with
# the checkpoint on the Hub.
encoder = AutoModel.from_pretrained(model_path, trust_remote_code=True)

hidden_dim = encoder.config.hidden_size
num_encoder_params = sum(weight.numel() for weight in encoder.parameters())
print(f"Encoder hidden size: {hidden_dim}")
print(f"Number of encoder parameters: {num_encoder_params:,}")

# LoRA for efficient fine-tuning
LoRA (Low-Rank Adaptation) enables efficient fine-tuning by adding small trainable adapter layers while keeping the base model frozen. This dramatically reduces training time and memory usage.

In [None]:
# LoRA settings, collected in one place before building the config
lora_settings = {
    "r": 8,                                # rank of the low-rank update matrices
    "lora_alpha": 16,                      # scaling factor for the update
    "target_modules": ["query", "value"],  # attention projections that get adapters
    "lora_dropout": 0.1,                   # dropout inside the adapter layers
    "bias": "none",                        # leave bias terms untrained
    "task_type": None,                     # custom head, so no predefined task type
}
lora_config = LoraConfig(**lora_settings)

# Wrap the encoder with the adapters and report how many parameters are trainable
encoder = get_peft_model(encoder, lora_config)
encoder.print_trainable_parameters()

# Defining an architecture for fine-tuning
We will create a `RegressionModel` that combines:
- A pretrained MIST encoder
- A MLP (multi-layer perceptron) regression head with tunable architecture

### Hyperparameter Tuning for Task Network

Before training, we can use Bayesian optimization to find optimal hyperparameters for the task network. We'll tune:
- **Dropout rate**: Regularization strength
- **Number of hidden layers**: Model complexity
- **Learning rate**: Optimization step size
- **Batch size**: Training batch size

We will use [Optuna's TPE (Tree-structured Parzen Estimator)](https://hub.optuna.org/samplers/tpe_tutorial/) sampler for efficient Bayesian optimization:

In [None]:
# Search for good task-head hyperparameters.
# Runtime grows with n_trials, so size it to your compute budget.
# NOTE: the test split is passed as the validation set here, so metrics
# later reported on that same split will be optimistic.
tuning_kwargs = {
    "encoder": encoder,
    "tokenizer": tokenizer,
    "train_dataset": tokenized_dataset["train"],
    "val_dataset": tokenized_dataset["test"],
    "device": device,
    "n_trials": 5,  # increase for better results
}
best_params, study = tune_hyperparameters(**tuning_kwargs)

# To skip hyperparameter tuning, comment out the call above and use these
# default parameters instead (the cells that visualize `study` must then
# be skipped as well):
# best_params = {
#     'task_hidden_size': 512,
#     'dropout': 0.1,
#     'num_hidden_layers': 1,
#     'learning_rate': 1.6e-4,
#     'batch_size': 32
# }

print("Using hyperparameters:")
for param_name, param_value in best_params.items():
    print(f"  {param_name}: {param_value}")

### Visualize Optimization Results

Let's analyze the impact of hyperparameters on the model's validation loss.

In [None]:
# Objective value per trial, drawn on a log-scaled y axis
fig1 = plot_optimization_history(study)
fig1.update_layout(
    template='simple_white',
    height=400,
    width=700,
).update_yaxes(type="log")
fig1.show()

In [None]:
# Parallel-coordinates view of every trial: one axis per tuned
# hyperparameter plus the objective value, coloured by that value
trials_df = study.trials_dataframe()

axis_labels = {
    'params_task_hidden_size': 'Hidden Size',
    'params_dropout': 'Dropout',
    'params_num_hidden_layers': '# Layers',
    'params_learning_rate': 'Learning Rate',
    'params_batch_size': 'Batch Size',
    'value': 'Val Loss'
}
plotted_columns = [
    'params_dropout',
    'params_num_hidden_layers',
    'params_learning_rate',
    'params_batch_size',
    'value',
]

fig = px.parallel_coordinates(
    trials_df,
    dimensions=plotted_columns,
    color='value',
    color_continuous_scale='Viridis',
    labels=axis_labels,
)
fig.update_layout(template='simple_white', height=400)
fig.show()

In [None]:
# Attach the regression head to the encoder, sized by the tuned hyperparameters
task_head_settings = {
    "hidden_size": encoder.config.hidden_size,
    "dropout": best_params['dropout'],
    "num_hidden_layers": best_params['num_hidden_layers'],
}
model = RegressionModel(encoder=encoder, **task_head_settings)


# Optional: freeze the encoder so that only the task head is trained
# for param in model.encoder.parameters():
#     param.requires_grad = False

model = model.to(device)

# Parameter bookkeeping: overall, trainable, and per-component counts
param_sizes = [(weight.numel(), weight.requires_grad) for weight in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
lora_param_count = sum(
    weight.numel() for weight in encoder.parameters() if weight.requires_grad
)
task_head_param_count = sum(weight.numel() for weight in model.task_head.parameters())

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"  - LoRA parameters: {lora_param_count:,}")
print(f"  - Task head parameters: {task_head_param_count:,}")

# Fine-tune the model!

In [None]:
# Training configuration; batch size and learning rate come from the tuned
# hyperparameters
tuned_batch_size = best_params['batch_size']
use_mixed_precision = torch.cuda.is_available()  # fp16 only when CUDA is available

training_args = TrainingArguments(
    output_dir="./finetuned_model",
    # Optimisation schedule
    num_train_epochs=10,
    learning_rate=best_params['learning_rate'],
    warmup_ratio=0.1,
    per_device_train_batch_size=tuned_batch_size,
    per_device_eval_batch_size=tuned_batch_size,
    fp16=use_mixed_precision,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint
    # with the lowest eval loss
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    save_total_limit=2,
    # Logging
    logging_steps=100,
    report_to="none",  # disable wandb/tensorboard
)

# Pads each batch using the tokenizer's padding settings
data_collator = DataCollatorWithPadding(tokenizer)

print("Training configuration:")
for setting_label, setting_value in (
    ("Epochs", training_args.num_train_epochs),
    ("Batch size", training_args.per_device_train_batch_size),
    ("Learning rate", training_args.learning_rate),
    ("Warmup ratio", training_args.warmup_ratio),
    ("Device", device),
):
    print(f"  {setting_label}: {setting_value}")

In [None]:
# Create Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Steps per epoch, rounded up: the final, partially filled batch is still a
# step, so floor division would undercount whenever the train size is not a
# multiple of the batch size.
# NOTE(review): assumes a single device, no gradient accumulation, and that
# the last partial batch is kept — confirm if those settings change.
train_size = len(tokenized_dataset["train"])
steps_per_epoch = (train_size + best_params["batch_size"] - 1) // best_params["batch_size"]

# Train!
print("Starting training...")
print(f"Total training steps: {steps_per_epoch * training_args.num_train_epochs}")
print(f"Evaluation every: {steps_per_epoch} steps\n")

trainer.train()

print("Training complete!")

In [None]:
# Save the fine-tuned model and its tokenizer into the same directory
save_path = "./finetuned_model"
# Write the model held by the Trainer to save_path
trainer.save_model(save_path)
# Store the tokenizer files alongside the model
tokenizer.save_pretrained(save_path)
print(f"Model saved to {save_path}")

# Evaluation 

Let's evaluate the fine-tuned model on a subset (up to 1,000 molecules) of the held-out test split.

In [None]:
from torch.utils.data import DataLoader

# Evaluate on at most the first 1000 test examples, batched with the same
# padding collator used for training
subset_size = min(1000, len(tokenized_dataset["test"]))
test_subset = tokenized_dataset["test"].select(range(subset_size))
test_dataloader = DataLoader(test_subset, batch_size=32, collate_fn=data_collator)

# Accumulate per-example predictions and their targets
test_predictions = []
test_labels = []

model.eval()
with torch.no_grad():
    for batch in test_dataloader:
        # Labels stay on the CPU; every other field is moved to the model's device
        targets = batch['labels']
        features = {
            name: tensor.to(device)
            for name, tensor in batch.items()
            if name != 'labels'
        }
        outputs = model(**features)
        batch_preds = outputs["y_pred"].squeeze(-1).cpu().numpy()

        test_predictions.extend(batch_preds)
        test_labels.extend(targets.numpy())

test_predictions = np.array(test_predictions)
test_labels = np.array(test_labels)

# Regression metrics computed from the residuals
residuals = test_predictions - test_labels
squared_residuals = residuals**2
total_variation = np.sum((test_labels - test_labels.mean())**2)

mae = np.mean(np.abs(residuals))
rmse = np.sqrt(np.mean(squared_residuals))
r2 = 1 - np.sum(squared_residuals) / total_variation

print(f"Test Set Performance (n={len(test_predictions)}):")
print(f"  MAE:  {mae:.4f} Hartree")
print(f"  RMSE: {rmse:.4f} Hartree")
print(f"  R²:   {r2:.4f}")

In [None]:
# Parity plot: predicted vs. actual LUMO, coloured by absolute error
fig = go.Figure()

# Scatter plot of predictions vs actual
fig.add_trace(go.Scatter(
    x=test_labels,
    y=test_predictions,
    mode='markers',
    name='Predictions',
    marker=dict(
        size=6,
        color=np.abs(test_predictions - test_labels),
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title="Absolute Error"),
        opacity=0.6,
        line=dict(width=0.5, color='white')
    ),
))

# Add perfect prediction line (y = x).
# Only the two endpoints are needed: tracing the line through every
# (unsorted) label draws many overlapping dashed segments, which smears
# the dash pattern into a solid line.
diagonal_ends = [float(test_labels.min()), float(test_labels.max())]
fig.add_trace(go.Scatter(
    x=diagonal_ends,
    y=diagonal_ends,
    mode='lines',
    name='Perfect Prediction',
    line=dict(color='red', width=2, dash='dash'),
    hoverinfo='skip'
))

# Update layout
fig.update_layout(
    xaxis_title='Actual LUMO (Hartree)',
    yaxis_title='Predicted LUMO (Hartree)',
    template='simple_white',
    height=600,
    width=700,
    showlegend=False,
)

# Make axes equal. Anchoring y to x is enough; also anchoring x to y forms
# a loop that Plotly treats as redundant and ignores.
fig.update_yaxes(scaleanchor="x", scaleratio=1)

fig.show()