# Property Prediction using MIST checkpoints

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/Ordinary_Differential_Equation_Solving_using_deepchem.ipynb)

``MIST`` is a suite of molecular Foundation Models with expanded coverage of chemical space trained using [``Smirk``](https://arxiv.org/abs/2409.15370), a novel tokenization scheme which captures a comprehensive representation of molecular structure including nuclear, electronic, and geometric features.

Motivated by scaling trends in NLP, the largest ``MIST`` models were trained with an order of magnitude more parameters and data than prior work, matching or exceeding the state-of-the-art across diverse chemical benchmarks.

**This notebook will walk through loading fine-tuned ``MIST`` checkpoints and predicting properties for molecules of interest to you.**

In [None]:
import os
import json
import torch
import pandas as pd
from smirk import SmirkTokenizerFast
from mist_demo import models
from mist_demo.models.utils import load_model


# Choose the compute backend in order of preference: CUDA, then Apple's MPS
# backend, and finally the CPU when neither accelerator reports as available.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

def load_checkpoint(path):
    """Load the model stored in a checkpoint directory onto ``device``.

    Reads ``config.json`` inside ``path``, takes the first entry of its
    ``"architectures"`` list as the name of a class in the imported ``models``
    package, and instantiates that class with ``from_pretrained(path)``.

    Args:
        path: Directory containing ``config.json`` and the checkpoint files
            expected by the model class's ``from_pretrained``.

    Returns:
        The object returned by ``from_pretrained(path)`` after ``.to(device)``,
        where ``device`` is the module-level device string.

    Raises:
        FileNotFoundError: If ``config.json`` is missing from ``path``.
        KeyError, IndexError: If the config has no ``"architectures"`` entry.
        AttributeError: If the named architecture is not found in ``models``.
    """
    config_path = os.path.join(path, "config.json")
    with open(config_path, "r", encoding="utf-8") as file:
        config = json.load(file)

    architecture = config["architectures"][0]
    # Resolve the class by attribute lookup rather than eval(): the name comes
    # from a file on disk and must never be executed as Python code. Dotted
    # names (e.g. "submodule.ClassName") are walked one attribute at a time so
    # they resolve the same way the previous eval-based lookup did.
    model_class = models
    for attribute in architecture.split("."):
        model_class = getattr(model_class, attribute)

    return model_class.from_pretrained(path).to(device)

# model = MISTMultiTask.from_pretrained(".").eval().to(device)
# tok = SmirkTokenizerFast()

# smi = [
#     "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
#     "CCN(CC)C(=O)[C@H]1CN([C@@H]2Cc3c[nH]c4c3c(ccc4)C2=C1)C",
#     "CCC(=O)OC1(C(CC2C1(CC(C3(C2CC(C4=CC(=O)C=CC43C)F)F)O)",
#     "CN3[C@H]1CC[C@@H]3C[C@@H](C1)OC(=O)C(CO)c2cc",
#     "CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O",
# ]

# pred = model.predict(smi)
# df = pd.DataFrame({k : v["value"] for k, v in pred.items()}, index=smi)

# df.to_csv("pred.csv")

# Checkpoint directory to evaluate; load_checkpoint reads its config.json to
# decide which model class to instantiate.
path = "models/mist-1.8B-3fbbz4is-h298"
model = load_checkpoint(path)
# SMILES strings of the molecules to run through the model.
# NOTE(review): the third entry has more opening than closing parentheses and
# the fourth opens ring bond `2` without closing it -- both look truncated.
# Confirm against the intended molecules before relying on their predictions.
smi = [
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "CCN(CC)C(=O)[C@H]1CN([C@@H]2Cc3c[nH]c4c3c(ccc4)C2=C1)C",
    "CCC(=O)OC1(C(CC2C1(CC(C3(C2CC(C4=CC(=O)C=CC43C)F)F)O)",
    "CN3[C@H]1CC[C@@H]3C[C@@H](C1)OC(=O)C(CO)c2cc",
    "CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O",
]

# Run the loaded model on the whole list in one call.
# presumably `pred` maps property names to dicts with a "value" entry, as the
# commented-out DataFrame code in this cell suggests -- TODO confirm.
pred = model.predict(smi)