In [1]:
import os

import torch
import pytorch_lightning as pl

from mist.models.roberta_base import RoBERTa
from mist.data_modules.roberta_dataset import RobertaDataSet
from mist.utils.lr_schedule import RelativeCosineWarmup

# enable RUST based parallelism for tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "true"

2024-07-16 16:22:17.710192: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Pre-training Dataset

We use a subset of randomly sampled molecules from [Enamine’s REAL Space Chemical Library](https://enamine.net/compound-collections/real-compounds/real-space-navigator), which is currently the largest library of commercially available compounds with 48B virtual products based on ~0.1M reagents and building blocks and 166 defined chemical rules to combine them. 

This pre-training dataset covers a significant fraction of the space of possible molecules. The plot below visualizes the chemical space covered by the pre-training dataset using the [TMAP](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-0416-x) (Tree Manifold Approximation and Projection) algorithm and compare it to the chemical space convered by datasets in MoleculeNet. MoleculeNet is a popular cheminformatics benchmark and is representative of datasets typically used to train machine learning models for chemistry.

<img src="figures/MIST_TMAP.png" alt="tmap" width="50%" display="block" margin-left="auto;" margin-right="auto;">


The molecules are stored as SMILES (Simplified Molecular-Input Line-Entry System) strings. SMILES are a cheminformatic line notation for describing chemical structures using short ASCII strings. SMILES strings are like a connection table in that they identify the nodes and edges of a molecular graph. In SMILES, hydrogen are typically implicitly implied and atoms are represented by their atomic symbol enclosed in brackets unless they are elements of the “organic subset” (`B`, `C`, `N`, `O`, `P`, `S`, `F`, `Cl`,`Br`, and `I`), which do not require brackets unless they are charged. So gold would be `[Au]` but chlorine would be just `Cl`. If hydrogens are explicitly implied brackets are used. A formal charge is represented by one of the symbols `+` or `-`. Single, double, triple, and aromatic bonds are represented by the symbols, `-`, `=`, `#`, and `:`, respectively. Single and aromatic bonds may be, and usually are, omitted. Below is an example of a SMILES string and the corresponding 2D molecular graph.

<img src="figures/smiles.png" alt="smiles" width="50%" display="block" margin-left="auto;" margin-right="auto;">


In [2]:
pretraining_data_path = "/eagle/FoundEnergy/realspace_v2"
tokenizer = "ibm/MoLFormer-XL-both-10pct"
mlm_probability = 0.15 
batch_size = 64
val_batch_size = 1

datamodule = RobertaDataSet(
    path=pretraining_data_path,
    tokenizer=tokenizer,
    batch_size=batch_size,
    val_batch_size=val_batch_size
)

### Initialize Model

In [3]:
vocab_size = datamodule.tokenizer.vocab_size
max_position_embeddings = 512
num_attention_heads = 12
num_attention_heads = 6
num_hidden_layers = 6
hidden_size = 768
intermediate_size = 768
relative_cosine_scheduler = lambda optimizer: RelativeCosineWarmup(optimizer, num_warmup_steps="beta2", num_training_steps=50_000)

model = RoBERTa(
    vocab_size=vocab_size,
    max_position_embeddings=max_position_embeddings,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    optimizer = torch.optim.AdamW,
    lr_schedule = relative_cosine_scheduler
)


### Initialize Trainer

In [4]:
# Some callbacks are defined for convinience
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

lr_monitor = LearningRateMonitor(logging_interval="step") # monitors and logs learning rate for schedulers during training

checkpoint_callback = ModelCheckpoint(save_last="link",
        filename="epoch={epoch}-step={step}-val_loss={val/loss_epoch:.2f}",
        monitor="val/loss_epoch",
        save_top_k=5,
        verbose=True,
        auto_insert_metric_name=False
    ) # saves the best model during training based on validation loss


## Training!

The pre-training strategy we use is analogous the MLM (Masked Language Modeling) used in NLP (Natural Language Processing). 
Part of the SMILES string is replace with a 'mask'. The objective is a cross-entropy loss on predicting the masked tokens.

<img src="figures/MIST_pretraining.png" alt="tmap" width="50%" display="block" margin-left="auto;" margin-right="auto;">

In [5]:
trainer = pl.Trainer(
    precision = "16-mixed", # Combines FP32 and lower-bit floating points to reduce memory footprint and increase performance.
    strategy = "ddp_notebook", # Distributed Data Parallel training.
    use_distributed_sampler = False,  # Handled by DataModule (needed due to IterableDataset).
    limit_train_batches=5000, # Limit train batches for shorter training time 
    limit_val_batches=10, 
    max_epochs=1, 
    devices=torch.cuda.device_count(),
    callbacks=[lr_monitor, checkpoint_callback])

trainer.fit(model=model, datamodule=datamodule)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  self.pid = os.fork()
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  self.pid = os.fork()
You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
You are using a 

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type               | Params | Mode 
-----------------------------------------------------
0 | model | RobertaForMaskedLM | 24.1 M | train
-----------------------------------------------------
24.1 M    Trainable params
0         Non-trainable params
24.1 M    Total params
96.335    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/abhutani/mist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/abhutani/mist/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 5000: 'val/loss_epoch' reached 2.20291 (best 2.20291), saving model to '/home/abhutani/mist/notebooks/lightning_logs/version_5/checkpoints/epoch=0-step=5000-val_loss=2.20.ckpt' as top 5
`Trainer.fit` stopped: `max_epochs=1` reached.


## Inference

In [6]:
datamodule.setup(stage="test")
for step, sample in enumerate(datamodule.val_dataloader()):
    print("Masked Molecule", datamodule.tokenizer.decode(sample['input_ids'].flatten()))
    mask = sample['labels'].flatten()!=-100
    labels = sample['labels'].flatten()[mask]
    print("Labels", datamodule.tokenizer.convert_ids_to_tokens(labels))
    pred = trainer.model.model(
        input_ids=sample['input_ids'],
        attention_mask = sample['attention_mask'],
    )
    pred = pred.logits[0].argmax(axis=1)[mask]
    pred = datamodule.tokenizer.convert_ids_to_tokens(pred)
    print("Labels", pred)
    print("_"*150)
    if step > 3:
        break

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Masked Molecule <bos>C[13C-]S[Se@@]=O)(<mask>O<mask>NCC1=CN(CCO<mask>COCCNC(=O)CN2N=C3<mask>O<mask>CN3<mask>(=O<mask>C2=O)[n-]=<mask>1<eos>
Labels ['C', '(', '=', ')', 'C', 'C', 'C', 'C', ')', 'N', 'N']


AttributeError: 'RoBERTa' object has no attribute 'model'