In [11]:
import os

import torch
import pytorch_lightning as pl

from mist.models.roberta_base import RoBERTa
from mist.data_modules.roberta_dataset import RobertaDataSet
from mist.utils.lr_schedule import RelativeCosineWarmup

# Enable the Rust-based parallelism of the tokenizers.
# This is set via an environment variable, so it is done right after the imports,
# before any tokenizer is created in the cells below.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

ImportError: cannot import name 'RobertaDataSet' from 'mist.data_modules.roberta_dataset' (/home/abhutani/mist/mist/data_modules/roberta_dataset.py)

In [8]:
from mist.models import roberta_base

In [10]:
# Debugging aid: display the module's namespace to inspect which names
# `mist.models.roberta_base` actually defines.
# NOTE(review): the output is very large; consider removing this cell from the final notebook.
roberta_base.__dict__

{'__name__': 'mist.models.roberta_base',
 '__doc__': None,
 '__package__': 'mist.models',
 '__loader__': <_frozen_importlib_external.SourceFileLoader at 0x14bb245aff10>,
 '__spec__': ModuleSpec(name='mist.models.roberta_base', loader=<_frozen_importlib_external.SourceFileLoader object at 0x14bb245aff10>, origin='/home/abhutani/mist/mist/models/roberta_base.py'),
 '__file__': '/home/abhutani/mist/mist/models/roberta_base.py',
 '__cached__': '/home/abhutani/mist/mist/models/__pycache__/roberta_base.cpython-311.pyc',
 '__builtins__': {'__name__': 'builtins',
  '__doc__': "Built-in functions, types, exceptions, and other objects.\n\nThis module provides direct access to all 'built-in'\nidentifiers of Python; for example, builtins.len is\nthe full name for the built-in function len().\n\nThis module is not normally accessed explicitly by most\napplications, but can be useful in modules that provide\nobjects with the same name as a built-in value, but in\nwhich the built-in of that name is a

### Pre-training Dataset

We use a subset of randomly sampled molecules from [Enamine’s REAL Space Chemical Library](https://enamine.net/compound-collections/real-compounds/real-space-navigator), which is currently the largest library of commercially available compounds with 48B virtual products based on ~0.1M reagents and building blocks and 166 defined chemical rules to combine them. 

This pre-training dataset covers a significant fraction of the space of possible molecules. The plot below visualizes the chemical space covered by the pre-training dataset using the [TMAP](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-0416-x) (Tree Manifold Approximation and Projection) algorithm and compares it to the chemical space covered by datasets in MoleculeNet. MoleculeNet is a popular cheminformatics benchmark and is representative of datasets typically used to train machine learning models for chemistry.

<img src="figures/MIST_TMAP.png" alt="tmap" width="50%" style="display: block; margin-left: auto; margin-right: auto;">


The molecules are stored as SMILES (Simplified Molecular-Input Line-Entry System) strings. SMILES is a cheminformatics line notation for describing chemical structures using short ASCII strings. A SMILES string is like a connection table in that it identifies the nodes and edges of a molecular graph. In SMILES, hydrogens are typically implicit, and atoms are represented by their atomic symbol enclosed in brackets unless they are elements of the “organic subset” (`B`, `C`, `N`, `O`, `P`, `S`, `F`, `Cl`, `Br`, and `I`), which do not require brackets unless they are charged. So gold would be `[Au]` but chlorine would be just `Cl`. If hydrogens are specified explicitly, brackets are used. A formal charge is represented by one of the symbols `+` or `-`. Single, double, triple, and aromatic bonds are represented by the symbols `-`, `=`, `#`, and `:`, respectively. Single and aromatic bonds may be, and usually are, omitted. Below is an example of a SMILES string and the corresponding 2D molecular graph.

<img src="figures/smiles.png" alt="smiles" width="50%" style="display: block; margin-left: auto; margin-right: auto;">


In [3]:
# Settings for the pre-training datamodule.
data_path = "./sample_data/"
tokenizer = "ibm/MoLFormer-XL-both-10pct"
# NOTE(review): mlm_probability is defined here but is not passed to RobertaDataSet
# below — confirm whether the datamodule is expected to receive it.
mlm_probability = 0.15
batch_size = 64
val_batch_size = 1

# Build the datamodule from the settings above. `path` must use `data_path`:
# no variable named `pretraining_data_path` is defined in this notebook, so
# referring to it raises a NameError on a fresh kernel.
datamodule = RobertaDataSet(
    path=data_path,
    tokenizer=tokenizer,
    batch_size=batch_size,
    val_batch_size=val_batch_size,
)



### Initialize Model

In [5]:
# Model hyperparameters.
vocab_size = datamodule.tokenizer.vocab_size  # taken from the datamodule's tokenizer
max_position_embeddings = 512
num_attention_heads = 6  # single assignment; a dead `= 12` that was immediately overwritten is gone
num_hidden_layers = 6
hidden_size = 768
intermediate_size = 768


def relative_cosine_scheduler(optimizer):
    """Return a RelativeCosineWarmup schedule wrapping ``optimizer``.

    Passed to the model as a factory so the schedule is created once the
    optimizer exists. Uses ``num_warmup_steps="beta2"`` and 50,000 training steps.
    """
    return RelativeCosineWarmup(optimizer, num_warmup_steps="beta2", num_training_steps=50_000)


# The optimizer is passed as a class and the LR schedule as a factory function.
model = RoBERTa(
    vocab_size=vocab_size,
    max_position_embeddings=max_position_embeddings,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    optimizer=torch.optim.AdamW,
    lr_schedule=relative_cosine_scheduler,
)


### Initialize Trainer

In [6]:
# A couple of callbacks are defined for convenience.
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Monitors and logs the learning rate for schedulers during training.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Saves the best models during training based on validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss_epoch",
    filename="epoch={epoch}-step={step}-val_loss={val/loss_epoch:.2f}",
    auto_insert_metric_name=False,
    save_top_k=5,
    save_last="link",
    verbose=True,
)


## Training!

The pre-training strategy we use is analogous to the MLM (Masked Language Modeling) objective used in NLP (Natural Language Processing). 
Part of the SMILES string is replaced with a 'mask' token. The objective is a cross-entropy loss on predicting the masked tokens.

<img src="figures/MIST_pretraining.png" alt="pretraining" width="50%" style="display: block; margin-left: auto; margin-right: auto;">

In [11]:
trainer = pl.Trainer(
    # --- hardware / parallelism ---
    devices=torch.cuda.device_count(),
    # Distributed Data Parallel training.
    strategy="ddp_notebook",
    # Handled by DataModule (needed due to IterableDataset).
    use_distributed_sampler=False,
    # Combines FP32 and lower-bit floating points to reduce memory footprint
    # and increase performance.
    precision="16-mixed",
    # --- run length ---
    max_epochs=1,
    limit_train_batches=500,
    limit_val_batches=10,
    # --- callbacks ---
    callbacks=[lr_monitor, checkpoint_callback],
)

# Run training (and the limited validation pass) with the datamodule's loaders.
trainer.fit(model=model, datamodule=datamodule)


/home/abhutani/electrolyte_fm/.venv/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/abhutani/electrolyte_fm/.venv/lib/python3.11/s ...
Using 16bit Automatic Mixed Precision (AMP)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch

Epoch 0:   0%|                                                           | 0/500 [00:00<?, ?it/s]



Epoch 0: 100%|████████████████| 500/500 [01:15<00:00,  6.63it/s, v_num=12, train/loss_step=1.040]
Validation: |                                                              | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                         | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                            | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0:  10%|███▌                                | 1/10 [00:00<00:00, 15.53it/s]
Validation DataLoader 0:  20%|███████▏                            | 2/10 [00:00<00:00, 14.28it/s]
Validation DataLoader 0:  30%|██████████▊                         | 3/10 [00:00<00:00, 15.58it/s]
Validation DataLoader 0:  40%|██████████████▍                     | 4/10 [00:00<00:00, 17.06it/s]
Validation DataLoader 0:  50%|██████████████████                  | 5/10 [00:00<00:00, 17.70it/s]
Validation DataLoader 0:  60%|█████████████████████▌              | 6/10 [00:00<00:00, 17.99it/s]
Validation DataLoade

Epoch 0, global step 500: 'val/loss_epoch' reached inf (best inf), saving model to '/home/abhutani/electrolyte_fm/lightning_logs/version_12/checkpoints/epoch=0-step=500-val_loss=nan.ckpt' as top 5
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|█| 500/500 [01:19<00:00,  6.26it/s, v_num=12, train/loss_step=1.040, val/loss_step=


## Inference

In [16]:
# Show, for a handful of validation batches, the masked input, the true tokens at
# the selected positions, and the tokens the model predicts there.
model.eval()  # evaluation mode (e.g. turns off dropout) so predictions are not randomly perturbed
with torch.no_grad():  # inference only: no gradient tracking needed
    for step, sample in enumerate(datamodule.val_dataloader()):
        print("Masked Molecule", datamodule.tokenizer.decode(sample['input_ids'].flatten()))

        # Keep only positions whose label is not -100; these are the tokens printed as labels.
        mask = sample['labels'].flatten() != -100
        labels = sample['labels'].flatten()[mask]
        print("Labels", datamodule.tokenizer.convert_ids_to_tokens(labels))

        output = model(sample)
        # Highest-scoring token id at every position, flattened the same way as the
        # labels so the mask lines up for any batch size (not just a batch of one).
        pred_ids = output.logits.argmax(dim=-1).flatten()[mask]
        pred = datamodule.tokenizer.convert_ids_to_tokens(pred_ids)
        print("Predictions", pred)

        print("_" * 200)
        # Stop after the first five batches (steps 0..4).
        if step > 3:
            break

Masked Molecule <bos>CCS(=O<mask>(=[77Kr])NCC1=CN(CC<mask>CCOCCNC(=<mask>)CN2N=C<mask>C[AlH4-]C<mask>N3C(=O)C2=O)N=N1<eos>
Labels ['=', ')', 'O', 'O', 'O', '3', 'O', 'C']
Labels ['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
________________________________________________________________________________________________________________________________________________________________________________________________________
Masked Molecule <bos>COC(=O)C(<mask>C1=CN(CCN2CCC(N(C)C)C<mask><mask>N=N1)<mask>C(=<mask>)<mask>N1C=C(C<mask>N)=O)C[117Sn+4]<mask>O)<mask>C1=O<eos>
Labels ['C', 'N', '2', ')', 'N', 'O', 'C', '(', '(', '=', 'N']
Labels ['C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
________________________________________________________________________________________________________________________________________________________________________________________________________
Masked Molecule <bos>NC<mask>=O)CCC(N<mask>(=O)(=O)C1=CC=C(Cl<mask>C<mask>C1<mask>C<mask><mask><mask>)N[C