In [1]:
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset import LMDataModule
from model import TRANSFORMERS

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment configuration -------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Reproducibility: seed torch's RNGs (CPU and CUDA) before any model or
# data loader is created. Other libraries (`random`, numpy) are not
# seeded here.
SEED = 42
torch.manual_seed(SEED)

# Model / tokeniser selection.
MODEL_NAME = 'basic_transformer'  # key into the TRANSFORMERS registry
TOKENISER = 'character'

# Data settings.
BLOCK_SIZE = 8
BATCH_SIZE = 8
VALIDATION_SET_RATIO = 0.1

# Optimisation and architecture hyper-parameters.
LEARNING_RATE = 1e-3
N_HEADS = 2
N_EMBEDDINGS = 32

# Training run settings.
NUM_EPOCHS = 1
SAVE_DIR = '../output'  # passed to the Trainer as default_root_dir

In [3]:
# Build the data module from the configuration constants. Per the cell
# output, construction itself loads the corpus and sets up the tokeniser.
data_module = LMDataModule(
    tokeniser=TOKENISER,
    validation_set_ratio=VALIDATION_SET_RATIO,
    batch_size=BATCH_SIZE,
    block_size=BLOCK_SIZE,
)

Total corpus length: 149326361
Set up character tokeniser
Vocabulary size: 75
Vocabulary: -ljy9RIJhfivWUTg2nEcbYO"X1D3F;4Qa!,q.kVK(#7C8w0?ZLuB5t' :sd6HoAePNMSpz)rGmx
Imported data of shape torch.Size([134393724]) and type torch.int64
Imported data of shape torch.Size([14932637]) and type torch.int64


In [4]:
# Look up the model class by name in the registry, then instantiate it.
# The vocabulary size comes from the data module so the model matches
# the tokeniser that was set up there.
model_cls = TRANSFORMERS[MODEL_NAME]
model = model_cls(
    vocabulary_size=data_module.vocabulary_size,
    block_size=BLOCK_SIZE,
    n_embeddings=N_EMBEDDINGS,
    n_heads=N_HEADS,
    learning_rate=LEARNING_RATE,
)

In [5]:
# Checkpointing: keep the three checkpoints with the lowest validation loss.
#
# The monitored metric name contains a '/', so automatic metric-name
# insertion is disabled: with it enabled the name is copied into the
# filename and the slash creates an unintended sub-directory. The label is
# written out explicitly instead ('validation_loss=') and the placeholder
# only supplies the value. A '-' separates the model name from the epoch.
callbacks = [
    ModelCheckpoint(
        filename=MODEL_NAME + '-epoch={epoch}-validation_loss={validation/loss:.3f}',
        auto_insert_metric_name=False,
        monitor='validation/loss',
        mode='min',
        save_top_k=3,
        verbose=True,
    )
]

In [6]:
# Configure the Lightning trainer: a single GPU, a full run (not a
# fast-dev smoke test), with logs and checkpoints rooted at SAVE_DIR.
trainer = Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=NUM_EPOCHS,
    default_root_dir=SAVE_DIR,
    callbacks=callbacks,
    fast_dev_run=False,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Fetch the loaders explicitly from the data module, then start training.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type              | Params
---------------------------------------------------------------
0 | token_embedding_table    | Embedding         | 2.4 K 
1 | position_embedding_table | Embedding         | 256   
2 | self_attention_head      | SelfAttentionHead | 1.5 K 
3 | lm_head                  | Linear            | 1.3 K 
---------------------------------------------------------------
5.5 K     Trainable params
0         Non-trainable params
5.5 K     Total params
0.022     Total estimated model params size (MB)


Epoch 0:   0%|          | 2734/18665794 [03:07<355:08:06, 14.60it/s, loss=2.47, v_num=7]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


**Test model output generation**

In [8]:
# Sample text from the model, starting from a single prompt token with id 0.
max_new_tokens = 1000

# Create the (1, 1) prompt directly on the model's device rather than
# building it on the CPU and copying it over.
inference_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)

# Generation needs no gradients: switch to eval mode (relevant if the model
# has dropout/normalisation layers) and disable autograd so no graph is
# accumulated over the sampled tokens.
model.eval()
with torch.no_grad():
    inference_results = model.generate(inference_input, max_new_tokens)

# Decode each generated sequence back to text.
for inference_result in inference_results:
    print(data_module.decode(inference_result))

-off ining aniddirouly haptetinyithe fs spes ho tok. See pe gherime Iula win wis sfsfocot nthapuvounxairy hebre asrnteve ed whpadees nng at meReofr. Therye'nd. Thatatu Pab pin vendicer Thortls rtin. Hot g bansyltoumavowouer-y mondof anse halshe toca she ase bye cheveo SAnsed thime Pnt wighele btheve lu lo Mauan g thav. Shed lolisothaf thin. "Nece the wanlind thle thank g re gef benosedarod ayeded ot. "As elt." I Aro weme the in arevaud sf tholliltiristhade Met cheearlrory hecrglouf he I, he m flor st tha p acoses sig #. Thed thapre b ayrat ar, es anis winghedoffa ototof att Pf wis.I hay btet she the. I bausgheed te afor'y y.e hiseand ank lde cy g, iut Ffor Jwe th he. " wuis e, ochem" winge theras ithean so wireuuy ohiat touseso r pthe I c- they cith merdotoorighan. ""The The Ludis ron'd, ont ol-- a mHe fpe or-pe onopathe ghed hin toomrolere Gousicusit houng. Thacth. He wseute phent cho btrofline witin om. An A of othen cheay Fulant q anthethe twhetlset hemerooful. Tidou cas Faf at sion