# LitMatter 🤗
* This notebook shows how to train large language models like ChemGPT and [ChemBERTa](https://arxiv.org/abs/2010.09885) using the LitMatter template.
* In this example, we train ChemGPT to generate new molecules.
* The training workflow shown here can be scaled to hundreds of GPUs by changing a single keyword argument!

In [3]:
%load_ext autoreload
%autoreload 2

In [5]:
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
                               seed_everything)

In [6]:
from lit_models.lit_hf import LitHF
from lit_data.lm_data import ChemDataModule

### Load a pretrained model with 🤗 transformers
Any model, tokenizer, and dataset from the 🤗 hub can be used with LitMatter.   
*N.B.* The ChemGPT tokenizers, models, and datasets are not yet publicly available through the 🤗 hub, so this notebook loads them from local directories. Check back soon!

In [7]:
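# Local directories holding the ChemGPT artifacts
tokenizer_dir = 'pubchem10M_tokenizer/'  # pretrained tokenizer
model_dir = 'chemgpt_models/'            # pretrained ChemGPT model
data_dir = 'pubchem10M_lmdataset'        # PubChem10M language-modeling dataset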

In [9]:
model = LitHF(tokenizer_dir=tokenizer_dir, model_dir=model_dir)

In [11]:
dm = ChemDataModule(data_dir=data_dir, tokenizer_dir=tokenizer_dir,
                    batch_size=8, num_workers=4)
dm.prepare_data()  # one-time data preparation
dm.setup()         # build the datasets used by the dataloaders

In [12]:
trainer = Trainer(gpus=-1,  # use all available GPUs on each node
                  # num_nodes=1,  # change to number of available nodes
                  # accelerator='ddp',  # distributed data parallel
                  max_epochs=5,
                  )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [14]:
trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-8a7b887c-c341-2b84-4231-85b99aa3f4d0]
Set SLURM handle signals.

  | Name  | Type              | Params
--------------------------------------------
0 | model | GPTNeoForCausalLM | 7.0 M 
--------------------------------------------
7.0 M     Trainable params
0         Non-trainable params
7.0 M     Total params
28.047    Total estimated model params size (MB)
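
The parameter count and the estimated size in the summary above are consistent with 4 bytes per parameter. The sketch below checks that arithmetic; the 32-bit assumption and the helper names are illustrative and are not taken from LitMatter or PyTorch Lightning.

```python
# Assumption: every parameter is stored as a 32-bit float (4 bytes).
BYTES_PER_PARAM_FP32 = 4


def estimated_size_mb(num_params: int, bytes_per_param: int = BYTES_PER_PARAM_FP32) -> float:
    """Estimated parameter memory in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6


def params_from_size_mb(size_mb: float, bytes_per_param: int = BYTES_PER_PARAM_FP32) -> int:
    """Invert the estimate: parameter count implied by a reported size."""
    return round(size_mb * 1e6 / bytes_per_param)


reported_size_mb = 28.047  # "Total estimated model params size (MB)" from the log
implied_params = params_from_size_mb(reported_size_mb)
print(f"Implied parameter count: {implied_params:,}")  # about 7.0 M, as in the summary
```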


Validation sanity check: 0it [00:00, ?it/s]

Val loss: 6.313976764678955
Val perplexity: 552.23670329373
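
The two validation numbers above are related: if the loss is the mean per-token cross-entropy in nats, perplexity is its exponential. The sketch below reproduces the reported perplexity from the reported loss under that assumption; the function name is illustrative.

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    """Perplexity for a mean per-token cross-entropy given in nats."""
    return math.exp(mean_cross_entropy)


val_loss = 6.313976764678955  # "Val loss" from the sanity check above
val_perplexity = perplexity(val_loss)
print(f"Val perplexity: {val_perplexity:.4f}")
```

A perplexity in the hundreds at this point only reflects the sanity-check batches run before any training step.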


Training: 0it [00:00, ?it/s]

  rank_zero_deprecation(
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


That's it! By setting the `num_nodes` argument to the number of available nodes, training is distributed across every GPU on every node (`gpus=-1` selects all GPUs on each node). For longer training jobs on an HPC cluster, see the provided example batch scripts.
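
One practical consequence of scaling out is that the effective batch size grows with the number of GPUs. The sketch below assumes that `batch_size=8` is applied per process, as is usual for distributed data parallel training; the node and GPU counts are hypothetical and the helper names are not part of LitMatter.

```python
def world_size(gpus_per_node: int, num_nodes: int = 1) -> int:
    """Total number of training processes, assuming one process per GPU."""
    if gpus_per_node < 1 or num_nodes < 1:
        raise ValueError("gpus_per_node and num_nodes must be positive")
    return gpus_per_node * num_nodes


def effective_batch_size(per_gpu_batch_size: int, gpus_per_node: int, num_nodes: int = 1) -> int:
    """Samples consumed per optimizer step across all processes."""
    return per_gpu_batch_size * world_size(gpus_per_node, num_nodes)


# batch_size=8 comes from the ChemDataModule call in this notebook.
single_gpu = effective_batch_size(8, gpus_per_node=1)
# Hypothetical cluster: 2 nodes with 4 GPUs each.
two_nodes = effective_batch_size(8, gpus_per_node=4, num_nodes=2)
print(single_gpu, two_nodes)
```

When the effective batch size changes this much, the learning rate may need retuning.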