# Property Prediction using MIST checkpoints

<a target="_blank" href="https://colab.research.google.com/github/BattModels/mist-demo/blob/main/tutorials/molecular_property_prediction.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

``MIST`` is a suite of molecular foundation models with expanded coverage of chemical space, trained using [``Smirk``](https://arxiv.org/abs/2409.15370), a novel tokenization scheme that captures a comprehensive representation of molecular structure, including nuclear, electronic, and geometric features.

Motivated by scaling trends in NLP, the largest ``MIST`` models were trained with an order of magnitude more parameters and data than prior work, matching or exceeding the state-of-the-art across diverse chemical benchmarks.

This notebook walks through loading finetuned ``MIST`` checkpoints and predicting properties for molecules of interest.

In [None]:
! pip install smirk

In [None]:
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import AutoModel, AutoTokenizer
from smirk import SmirkTokenizerFast

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

## Analyzing Trends in Characteristic Temperatures

We'll use three finetuned MIST models to predict boiling point, flash point, and melting point for alkenes and alcohols of varying chain lengths.

In [None]:
models_info = {
    'Boiling Point': 'mist-models/mist-26.9M-b302p09x-bp',
    'Flash Point': 'mist-models/mist-26.9M-cyuo2xb6-fp',
    'Melting Point': 'mist-models/mist-26.9M-y3ge5pf9-mp'
}

loaded_models = {}
for prop_name, model_path in models_info.items():
    print(f"Loading {prop_name} model...")
    loaded_models[prop_name] = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device).eval()


In [None]:
# Generate SMILES for alkenes and alcohols with varying chain lengths
def generate_alkene_smiles(n):
    """Generate linear 1-alkene SMILES for n carbons (terminal C=C, n >= 2)"""
    # The double bond takes the first two carbons; the rest form the chain
    return "C=C" + "C" * (n - 2)

def generate_alcohol_smiles(n):
    """Generate linear 1-alcohol SMILES for n carbons (OH at terminal)"""
    # n carbons followed by the hydroxyl oxygen:
    # n=1 -> "CO" (methanol), n=2 -> "CCO" (ethanol)
    return "C" * n + "O"

# Generate molecules for even chain lengths from 2 to 24
chain_lengths = range(2, 25, 2)

alkenes = {
    'name': [f'C{n} Alkene' for n in chain_lengths],
    'carbons': list(chain_lengths),
    'smiles': [generate_alkene_smiles(n) for n in chain_lengths]
}

alcohols = {
    'name': [f'C{n} Alcohol' for n in chain_lengths],
    'carbons': list(chain_lengths),
    'smiles': [generate_alcohol_smiles(n) for n in chain_lengths]
}
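
Because the names and the `carbons` column are built from `chain_lengths` rather than from the SMILES strings themselves, it can help to confirm that each string really contains the number of carbons its label claims. The snippet below is a minimal sketch of such a check. `count_carbons` is a hypothetical helper introduced here for illustration, and it assumes the plain unbranched chains used in this notebook (no chlorine, aromatic atoms, or bracket atoms).

```python
import re


def count_carbons(smiles):
    """Count aliphatic carbon atoms in a simple SMILES string.

    Assumption: the string is a plain chain such as "C=CCC" or "CCO".
    A 'C' followed by 'l' (chlorine) is skipped; aromatic 'c' and
    bracket atoms are not handled.
    """
    return len(re.findall(r"C(?!l)", smiles))


# Hand-written examples: SMILES string -> expected number of carbons
examples = {"C=C": 2, "C=CCC": 4, "CO": 1, "CCO": 2, "CCCCO": 4}

for smiles, expected in examples.items():
    assert count_carbons(smiles) == expected, (smiles, count_carbons(smiles))

print("all carbon counts match")
```

Running the same comparison over `zip(alkenes['carbons'], alkenes['smiles'])`, and over the alcohol equivalent, would flag any label that does not match its structure.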

In [None]:
# Predict for alkenes
alkene_predictions = {}
for prop_name, model in loaded_models.items():
    print(f"  Predicting {prop_name} for alkenes...")
    predictions = model.predict(alkenes['smiles'])
    prop_key = list(predictions.keys())[0]
    alkene_predictions[prop_name] = predictions[prop_key]['value'].cpu().numpy().tolist()

# Predict for alcohols
alcohol_predictions = {}
for prop_name, model in loaded_models.items():
    print(f"  Predicting {prop_name} for alcohols...")
    predictions = model.predict(alcohols['smiles'])
    prop_key = list(predictions.keys())[0]
    alcohol_predictions[prop_name] = predictions[prop_key]['value'].cpu().numpy().tolist()
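
The two loops above have the same shape, so they could be folded into one helper. The sketch below assumes, as the cells above do, that `predict` returns a dict whose first entry holds a `'value'` tensor. `FakeTensor` and `StubModel` are hypothetical stand-ins so the snippet runs without downloading any checkpoint, and the numbers they return are placeholder arithmetic, not chemistry.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: supports .cpu().numpy().tolist()."""

    def __init__(self, values):
        self._values = list(values)

    def cpu(self):
        return self

    def numpy(self):
        return self

    def tolist(self):
        return list(self._values)


class StubModel:
    """Hypothetical model whose predict() mimics the output layout assumed above."""

    def __init__(self, key, offset):
        self.key = key
        self.offset = offset

    def predict(self, smiles):
        # Placeholder arithmetic: string length plus an offset
        values = [len(s) + self.offset for s in smiles]
        return {self.key: {"value": FakeTensor(values)}}


def predict_all(models, smiles):
    """Run every model on the same SMILES list.

    Returns a dict mapping each property name to a list of floats.
    """
    results = {}
    for prop_name, model in models.items():
        predictions = model.predict(smiles)
        prop_key = next(iter(predictions))  # first (assumed only) output key
        results[prop_name] = predictions[prop_key]["value"].cpu().numpy().tolist()
    return results


stub_models = {
    "Boiling Point": StubModel("bp", 0.0),
    "Melting Point": StubModel("mp", -10.0),
}
demo = predict_all(stub_models, ["C=C", "C=CCC"])
print(demo)
```

With the real checkpoints, the equivalent calls would be `predict_all(loaded_models, alkenes['smiles'])` and `predict_all(loaded_models, alcohols['smiles'])`.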

In [None]:
fig = make_subplots(
    rows=3, cols=1,
    subplot_titles=('Boiling Point', 'Flash Point', 'Melting Point'),
    vertical_spacing=0.1  # spacing between the three stacked rows
)

properties = ['Boiling Point', 'Flash Point', 'Melting Point']
colors = {'Alkenes': '#1f77b4', 'Alcohols': '#ff7f0e'}

# Plot each property
for idx, prop_name in enumerate(properties, start=1):    
    # Alkenes trace
    alkene_vals = alkene_predictions[prop_name]
    fig.add_trace(
        go.Scatter(
            x=alkenes['carbons'],
            y=alkene_vals,
            mode='lines+markers',
            name='Alkenes',
            marker=dict(size=10, color=colors['Alkenes']),
            line=dict(width=3, color=colors['Alkenes']),
            showlegend=(idx == 1),
        ),
        row=idx, col=1
    )
    
    # Alcohols trace
    alcohol_vals = alcohol_predictions[prop_name]
    fig.add_trace(
        go.Scatter(
            x=alcohols['carbons'],
            y=alcohol_vals,
            mode='lines+markers',
            name='Alcohols',
            marker=dict(size=10, color=colors['Alcohols']),
            line=dict(width=3, color=colors['Alcohols']),
            showlegend=(idx == 1),
        ),
        row=idx, col=1
    )

    fig.update_yaxes(title_text="Temperature (°C)", row=idx, col=1)

# Label the x-axis on the bottom subplot only (the grid has 3 rows)
fig.update_xaxes(title_text="Carbon Chain Length", row=3, col=1)

# Update layout
fig.update_layout(
    height=600,
    template='simple_white',
    hovermode='closest',
)

fig.show()

The models produce chemically reasonable trends:
- With increasing chain length, all three properties rise for both alkenes and alcohols, as expected when molar mass and dispersion forces increase.
- Alcohols sit above alkenes for boiling point, flash point, and melting point, which matches the stronger intermolecular interactions (hydrogen bonding) in alcohols.
- For a given molecule, $T_m < T_{flash} < T_b$ (melting point below flash point below boiling point), which is physically reasonable.
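
These observations can also be checked numerically rather than read off the plot. The snippet below is a minimal sketch. The helper names are hypothetical, and the `toy_*` values are placeholder numbers chosen only to exercise the checks; they are not model output.

```python
def is_strictly_increasing(values):
    """True if each value is larger than the one before it."""
    return all(a < b for a, b in zip(values, values[1:]))


def all_above(upper, lower):
    """True if every element of `upper` exceeds the matching element of `lower`."""
    return all(u > l for u, l in zip(upper, lower))


def ordering_holds(t_melt, t_flash, t_boil):
    """True if T_m < T_flash < T_b for every molecule."""
    return all(m < f < b for m, f, b in zip(t_melt, t_flash, t_boil))


# Placeholder numbers for illustration only; NOT model output.
toy_alkenes = {
    "Melting Point": [-120.0, -90.0, -60.0],
    "Flash Point": [-40.0, -10.0, 20.0],
    "Boiling Point": [0.0, 40.0, 80.0],
}
toy_alcohols = {
    name: [v + 50.0 for v in vals] for name, vals in toy_alkenes.items()
}

increasing = {name: is_strictly_increasing(v) for name, v in toy_alkenes.items()}
alcohols_higher = {
    name: all_above(toy_alcohols[name], toy_alkenes[name]) for name in toy_alkenes
}
ordered = ordering_holds(
    toy_alkenes["Melting Point"],
    toy_alkenes["Flash Point"],
    toy_alkenes["Boiling Point"],
)
print(increasing, alcohols_higher, ordered)
```

Passing `alkene_predictions` and `alcohol_predictions` in place of the toy dictionaries applies the same checks to the actual predictions.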