# Getting Started

<a href="https://colab.research.google.com/github/BattModels/smirk/blob/main/docs/smirk_demo.ipynb">
    <img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg">
</a>
<a href="https://mybinder.org/v2/gh/BattModels/smirk/main?urlpath=%2Fdoc%2Ftree%2Fdocs%2Fsmirk_demo.ipynb">
    <img alt="Binder" src="https://mybinder.org/badge_logo.svg">
</a>

Molecular Foundation Models are demonstrating impressive performance, but current models use tokenizers that fail to represent *all* of chemistry, which inherently limits their performance. In particular, [Atom-wise] tokenizers emit a single token for any [bracketed atom], triggering a combinatorial explosion of the vocabulary size. Capturing all variants of Carbon atoms would require 75,600 tokens, or nearly a quarter of GPT-4o's vocabulary ([Wadell et al.][paper]).
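
To make the single-token behavior concrete, here is a minimal sketch of an atom-wise tokenizer. The regular expression follows the pattern commonly used for atom-wise SMILES tokenization; the helper name `atomwise_tokenize` is ours, and this is an illustration rather than the implementation used by any particular model.

```python
import re

# Atom-wise pattern: anything between square brackets is ONE token;
# everything else is split into organic-subset atoms, bonds, and ring digits.
ATOMWISE = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def atomwise_tokenize(smiles: str) -> list[str]:
    """Split a SMILES string into atom-wise tokens (illustrative helper)."""
    return ATOMWISE.findall(smiles)


cisplatin = atomwise_tokenize("Cl[Pt@SP1](Cl)([NH3])[NH3]")
transplatin = atomwise_tokenize("Cl[Pt@SP2](Cl)([NH3])[NH3]")
print(cisplatin)

# The two isomers differ by one character, yet each needs its own vocabulary entry
print(set(cisplatin) ^ set(transplatin))
```

Every distinct combination of isotope, chirality, hydrogen count, and charge inside the brackets becomes a separate vocabulary entry, which is where the growth in vocabulary size comes from.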

The problem is that most kinds of atoms must be bracketed: any element outside the organic subset, chiral center, isotope, or charged species is encoded as a bracketed atom. Bracketed atoms encode the nuclear, electronic, and geometric features that are critical to numerous widely used compounds, including:

- [Cisplatin]: An effective chemotherapy drug on the World Health Organization's List of Essential Medicines.
    However, its isomer [Transplatin] is not an effective drug.
- [Sodium pertechnetate]: A [radiopharmaceutical] used for thyroid imaging.
- [Lithium Iron Phosphate]: A widely used cathode material for batteries powering everything from consumer electronics to electric vehicles.

Smirk fixes this by tokenizing SMILES encodings all the way down to their constituent elements,
enabling complete coverage of [OpenSMILES] with a vocabulary of *167* tokens.
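
The effect of decomposing bracketed atoms can be sketched with the standard library alone. The regular expression below loosely follows the OpenSMILES bracket-atom fields (isotope, symbol, chirality, hydrogen count, charge, atom class). The helper `split_bracket_atom` and the option lists are hypothetical and chosen only for illustration; this is not smirk's implementation, and the pieces it produces are not necessarily smirk's tokens.

```python
import re
from itertools import product

# Loose approximation of an OpenSMILES bracket atom, one group per field
BRACKET = re.compile(
    r"\[(?P<isotope>\d+)?"
    r"(?P<symbol>[A-Z][a-z]?|[a-z]{1,2}|\*)"
    r"(?P<chiral>@(?:@|TH[12]|AL[12]|SP[1-3]|TB\d{1,2}|OH\d{1,2})?)?"
    r"(?P<hcount>H\d?)?"
    r"(?P<charge>\+\+|--|[+-]\d{0,2})?"
    r"(?P<atom_class>:\d+)?\]"
)


def split_bracket_atom(atom: str) -> list[str]:
    """Split one bracketed atom into its fields (illustrative helper)."""
    match = BRACKET.fullmatch(atom)
    if match is None:
        raise ValueError(f"not a bracketed atom: {atom!r}")
    return ["["] + [field for field in match.groups() if field is not None] + ["]"]


# Made-up option lists for carbon, NOT the enumeration used in the paper
isotopes = ["", "12", "13", "14"]
chirality = ["", "@", "@@"]
hydrogens = ["", "H", "H2", "H3"]
charges = ["", "+", "-"]

whole = {f"[{i}C{c}{h}{q}]" for i, c, h, q in product(isotopes, chirality, hydrogens, charges)}
pieces = {piece for atom in whole for piece in split_bracket_atom(atom)}

print(split_bracket_atom("[C@@H]"))
print(len(whole), "whole-bracket tokens vs", len(pieces), "decomposed pieces")
```

With whole-bracket tokens the vocabulary grows with the product of the options per field, whereas a decomposed vocabulary grows only with their sum.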

Check out the [paper] for all the details; otherwise, let's see it in action!

[OpenSMILES]: http://opensmiles.org/
[paper]: https://doi.org/10.48550/arXiv.2409.15370
[Atom-wise]: https://doi.org/10.1039/C8SC02339E
[bracketed atom]: https://en.wikipedia.org/wiki/Simplified_Molecular_Input_Line_Entry_System#Atoms

[Cisplatin]: https://en.wikipedia.org/wiki/Cisplatin
[Transplatin]: https://en.wikipedia.org/wiki/Transplatin
[Sodium pertechnetate]: https://en.wikipedia.org/wiki/Sodium_pertechnetate
[radiopharmaceutical]: https://en.wikipedia.org/wiki/Radiopharmaceutical
[Lithium Iron Phosphate]: https://en.wikipedia.org/wiki/Lithium_iron_phosphate

🐍 Installation is easy with pre-built binaries on [PyPI](https://pypi.org/project/smirk/) and [GitHub](https://github.com/BattModels/smirk/releases). Just run: `pip install smirk`

> Installing from source? See [installing from source](./developer.md#installing-from-source) for instructions.

In [None]:
!python -m pip install smirk transformers rdkit

## First steps

🤗 smirk subclasses Hugging Face's [PreTrainedTokenizerBase](#transformers.PreTrainedTokenizerBase) for seamless compatibility and leverages [Tokenizers] for raw Rust-powered speed. No need to learn another framework; everything works out of the box 🎁

[Tokenizers]: https://huggingface.co/docs/tokenizers/index

In [None]:
from smirk import SmirkTokenizerFast

# Just import and tokenize!
smirk = SmirkTokenizerFast()
smirk("CC(=O)Nc1ccc(O)cc1")

In [None]:
# Batch Tokenization with Padding
batch = smirk([
    "C[C@@H]1CCCCCCCCCCCCC(=O)C1",
    "O=C(O)C[C@H](N)C(=O)N[C@H](C(=O)OC)Cc1ccccc1",
    "CN(C)S[N][Re@OH18]([C][O])([C][O])([C][O])([C][O])[C][O]"
], padding="longest")
batch

In [None]:
# Back to molecules!
smirk.batch_decode(batch["input_ids"], skip_special_tokens=True)

In [None]:
# By default, we don't add `[CLS]` and `[SEP]` tokens, but that's just a flag
smirk_bert = SmirkTokenizerFast(template="[CLS] $0 [SEP]")
" ".join(smirk_bert.tokenize("CNCCC(c1ccccc1)Oc2ccc(cc2)C(F)(F)F", add_special_tokens=True))

## What Makes Smirk Special?

By fully decomposing the input molecule, `smirk` ensures complete coverage of the [OpenSMILES] specification. Any valid [OpenSMILES] encoding can be tokenized by `smirk` without emitting unknown tokens. Moreover, for non-bracketed atoms, the `smirk` tokenization is the same as an Atomwise tokenizer used by current molecular foundation models such as [MoLFormer].

[OpenSMILES]: http://opensmiles.org/
[MoLFormer]: https://doi.org/10.1038/s42256-022-00580-7

In [None]:
from rdkit import Chem
from rdkit.Chem.Draw import MolsToGridImage, rdMolDraw2D
from IPython.display import SVG
from transformers import AutoTokenizer

# Tokenizers being evaluated, see the paper for a more comprehensive study (30 tokenizers!)
# Or try adding one of the other tokenizers evaluated in the paper
tokenizers = {
    "smirk": smirk,
    "molformer": AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True),
    "GPT-4o": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
}

smi = [
    "Cl[Pt@SP1](Cl)([NH3])[NH3]", # Cisplatin 
    "Cl[Pt@SP2](Cl)([NH3])[NH3]", # Transplatin
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C", # Caffeine
    "[O-][99Tc](=O)(=O)=O.[Na+]", # Sodium pertechnetate with radiotracer
    "[Ga+]$[As-]", # Gallium arsenide
    "[OH2]", # Water
]

def get_legend(smi:str, tokenizers:dict):
    """Helper function for creating legends"""
    entries = []
    for name, tok in tokenizers.items():
        entries.append(f"{name}: {' '.join(tok.tokenize(smi))}")
    return "\n".join(entries)

# Draw all molecules and tokenizations on a grid
drawOptions = rdMolDraw2D.MolDrawOptions()
drawOptions.fixedScale = 1
drawOptions.centreMoleculesBeforeDrawing = True
drawOptions.minFontSize = 6
drawOptions.legendFontSize = 24
drawOptions.legendFraction = 0.3
MolsToGridImage(
    [Chem.MolFromSmiles(s) for s in smi],  # `s` avoids shadowing the `smi` list
    molsPerRow=2, subImgSize=(400,200),
    legends=[get_legend(s, tokenizers) for s in smi],
    drawOptions=drawOptions,
)

Smirk tokenized all molecules without a single unknown, whereas MoLFormer's Atomwise tokenizer emitted the unknown token for both [Cisplatin] and [Transplatin] (first row). In total, the Atomwise tokenizer emitted unknown tokens for the following:

- Platinum chiral centers: `[Pt@SP1]` and `[Pt@SP2]`
- Ammonia & Water with explicit hydrogens: `[NH3]` and `[OH2]`
- Gallium ion: `[Ga+]`
- Quadruple bond: `$`

As a data-driven method, Atomwise tokenizers only know about the atoms seen during their training, which fundamentally limits their ability to generalize.
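
This limitation can be reproduced in a few lines. The sketch below fits an atom-wise vocabulary on a tiny, hypothetical training corpus and then encodes molecules containing atoms that corpus never saw. The corpus, the `encode` helper, and the `[UNK]` placeholder are assumptions made for illustration and are not MoLFormer's actual vocabulary or code.

```python
import re

# Commonly used atom-wise pattern: each bracketed atom is a single token
ATOMWISE = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

# Hypothetical training corpus containing only small organic molecules
corpus = [
    "CC(=O)Nc1ccc(O)cc1",
    "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",
    "C[C@@H]1CCCCCCCCCCCCC(=O)C1",
    "ClC(Cl)Cl",
]

# The vocabulary is exactly the set of tokens observed during "training"
vocab = {token for smiles in corpus for token in ATOMWISE.findall(smiles)}


def encode(smiles: str) -> list[str]:
    """Replace every token missing from the fitted vocabulary with [UNK]."""
    return [tok if tok in vocab else "[UNK]" for tok in ATOMWISE.findall(smiles)]


print(encode("Cl[Pt@SP1](Cl)([NH3])[NH3]"))  # Cisplatin
print(encode("[Ga+]$[As-]"))  # Gallium arsenide
```

Here `[C@@H]` is known only because it appeared in the corpus; any bracketed atom outside the corpus collapses to the unknown token, no matter how closely it resembles one that was seen.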

[Transplatin]: https://en.wikipedia.org/wiki/Transplatin
[Cisplatin]: https://en.wikipedia.org/wiki/Cisplatin
[Gallium arsenide]: https://en.wikipedia.org/wiki/Gallium_arsenide

## Zero to Molecular Foundation Model with Smirk!

Let's train a small [RoBERTa] model on molecules from [QM9] using Hugging Face and smirk.

[QM9]: https://doi.org/10.1021/ja902302h
[RoBERTa]: https://doi.org/10.48550/ARXIV.1907.11692

In [None]:
!python -m pip install accelerate datasets torch

### Dataset Preprocessing

In [None]:
from datasets import load_dataset

# MoleculeNet's QM9 dataset. Normally this would be a larger (and unlabeled)
# dataset. But for a demo, it's perfect
dataset = load_dataset("csv", 
    data_files=["https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv"],
)["train"].select_columns("smiles").train_test_split(test_size=0.2)

# Tokenize the splits! For a larger dataset, this would be done on-the-fly
dataset = dataset.map(smirk, input_columns=["smiles"], desc="Tokenizing")

> 💡 huggingface/tokenizers may raise a warning about being forked because we've already used our tokenizer (this isn't a smirk issue).
> It's harmless, but when actually training it's best to avoid tokenizing until after the fork to benefit from the Rust-level parallelism.
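
If the warning is distracting in a demo, one option is to set the `TOKENIZERS_PARALLELISM` environment variable that huggingface/tokenizers reads. The sketch below assumes it runs before the tokenizer is first used in the process, and it trades away the Rust-level parallelism, so it is better suited to notebooks than to real training runs.

```python
import os

# Disable tokenizers' internal thread pool so forking worker processes is safe
# and the fork warning is not raised. Set this before the first tokenizer call.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
```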

🎉 That's it! We've tokenized all of QM9 using smirk!

In [None]:
dataset["train"].to_pandas().head()

### Training
Once we've tokenized the dataset, training the model is just a matter of configuration.

In [None]:
from accelerate import Accelerator
from transformers import Trainer, TrainingArguments, RobertaForMaskedLM, RobertaConfig, DataCollatorForLanguageModeling

# A very small model for demonstrating training a molecular foundation model with smirk 
config = RobertaConfig(
    vocab_size=len(smirk),
    hidden_size=256,
    intermediate_size=1024,
    num_hidden_layers=4,
    num_attention_heads=4,
)
model = RobertaForMaskedLM(config)

# Set up the trainer to use our dataset
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=smirk,
    data_collator=DataCollatorForLanguageModeling(smirk), # The data collator needs to know about our tokenizer
)

In [None]:
trainer.train()