In [9]:
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset import LMDataModule
from model import Transformer

In [10]:
# Run configuration: every tunable knob for data, model and training lives here.
MODEL_NAME = 'transformer'  # prefix used for checkpoint filenames
TOKENISER = 'character'
BLOCK_SIZE = 8  # shared by the data module and the model
BATCH_SIZE = 8
VALIDATION_SET_RATIO = 0.1
LEARNING_RATE = 1e-3
N_HEADS = 2
N_EMBEDDINGS = 32

NUM_EPOCHS = 1
SAVE_DIR = '../output'

# Seed torch's RNGs (on all devices) before anything stochastic happens. This
# covers weight initialisation and, presumably, dataloader shuffling and the
# sampling in generate(), so that Restart & Run All gives reproducible results.
SEED = 42
torch.manual_seed(SEED)

In [11]:
# Build the data module from the configuration cell. Constructing it loads the
# corpus and prints the vocabulary and the train/validation tensor shapes.
# NOTE(review): the printed vocabulary is not in sorted order, which suggests it
# is built from an unordered set in dataset.py. If so, the char -> id mapping can
# differ between Python processes (str hashing is randomised), so a checkpoint
# saved in one session may decode incorrectly in another — TODO confirm and sort.
data_module = LMDataModule(
    tokeniser=TOKENISER,
    block_size=BLOCK_SIZE,
    batch_size=BATCH_SIZE,
    validation_set_ratio=VALIDATION_SET_RATIO,
)

Total corpus length: 149326361
Set up character tokeniser
Vocabulary size: 75
Vocabulary: zu'3#MI"h)2KP:1opYa L5NAOw?x,!DW9fvZqSFEX7TeBlmJVbQGsrj;48Hg-dR6t(U0nCikcy.
Imported data of shape torch.Size([134393724]) and type torch.int64
Imported data of shape torch.Size([14932637]) and type torch.int64


In [12]:
# Instantiate the language model. Its vocabulary size is taken from the data
# module's tokeniser; all other hyper-parameters come from the configuration cell.
# NOTE(review): the parameter summary printed by trainer.fit lists a single
# SelfAttentionHead (~1.5 K params) and an lm_head of ~1.3 K params. That is
# consistent with one head of size N_EMBEDDINGS // N_HEADS = 16 rather than
# N_HEADS separate heads — confirm how model.py uses n_heads.
model = Transformer(
    n_heads=N_HEADS,
    n_embeddings=N_EMBEDDINGS,
    block_size=BLOCK_SIZE,
    learning_rate=LEARNING_RATE,
    vocabulary_size=data_module.vocabulary_size,
)

In [13]:
# Keep the three checkpoints with the lowest validation loss.
#
# Filename: the metric key contains '/', and with Lightning's default
# auto_insert_metric_name=True that '/' is copied into the filename
# ("...-validation/loss=2.370.ckpt"), which creates an unintended sub-directory.
# Auto-insertion is therefore disabled and the metric names are written out
# explicitly. The '-' after MODEL_NAME separates it from the epoch (previously the
# result was "transformerepoch=0...").
# Resulting name, e.g.: "transformer-epoch=0-val_loss=2.370.ckpt".
#
# 'validation/loss' is presumably logged by the model's validation loop. With the
# Trainer's default once-per-epoch validation, no checkpoint is written until an
# epoch completes.
callbacks = [
    ModelCheckpoint(
        filename=MODEL_NAME + '-epoch={epoch}-val_loss={validation/loss:.3f}',
        monitor='validation/loss',
        auto_insert_metric_name=False,
        verbose=True,
        save_top_k=3,
        mode='min',
    )
]

In [14]:
# Train on a single GPU when CUDA is available. Otherwise fall back to the CPU so
# the notebook still runs end-to-end (Restart & Run All) on a machine without a GPU.
# NOTE: the progress bar recorded for trainer.fit reports ~18.7 M batches for a
# single epoch (~182 h at the observed rate). Consider limit_train_batches /
# limit_val_batches / val_check_interval for quicker experiments.
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir=SAVE_DIR,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    callbacks=callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
# Fit the model on the data module's train/validation loaders. Interrupting the
# kernel is handled gracefully by Lightning (see the KeyboardInterrupt warning in
# the output), which leaves `model` in its partially trained state for the cells below.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type              | Params
---------------------------------------------------------------
0 | token_embedding_table    | Embedding         | 2.4 K 
1 | position_embedding_table | Embedding         | 256   
2 | self_attention_head      | SelfAttentionHead | 1.5 K 
3 | lm_head                  | Linear            | 1.3 K 
---------------------------------------------------------------
5.5 K     Trainable params
0         Non-trainable params
5.5 K     Total params
0.022     Total estimated model params size (MB)


Epoch 0:   0%|          | 20833/18665794 [12:11<181:54:14, 28.47it/s, loss=2.37, v_num=3]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


**Test model output generation**: sample 1,000 characters from the (partially trained) model, starting from token id 0.

In [16]:
# Sample `max_new_tokens` tokens from the model, seeded with a (1, 1) tensor
# holding token id 0, and decode each resulting sequence to text.
#
# Generation is pure inference:
# - eval() disables any train-only layers (e.g. dropout, if the model has any);
# - no_grad() stops autograd from recording a graph across all autoregressive
#   steps, which would only waste memory.
# The previous train/eval mode is restored afterwards so the model can still be
# trained further.
inference_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)
max_new_tokens = 1000

was_training = model.training
model.eval()
try:
    with torch.no_grad():
        inference_results = model.generate(inference_input, max_new_tokens)
finally:
    model.train(was_training)

for inference_result in inference_results:
    print(data_module.decode(inference_result))

zed olf quune thacl Og llard oue'cr, uvawl of mousrthiall unne ded ormof tst theet scolocicherf as ourscany, thang ith cay thino some ond wrattid lbewe st fag, mursincouneveredn bengenrt'cofiet thited aiver inen fsm. Tepulnd perrerattiveve an lfo hawingthud hicror srodon the thind tancains, hmet as theaneass wis,"" The to yas ef an toroy ther an taco st No atan the thes Mitotitm as ssupnts ther nwaleny onst at and the cheed Wans ist beerid otoitheitth preenycing. Loond ffre nc or amus, che date wo nd orusine gcr tunsth ts plinte cale," nc: rausry hapr ff titvicer fol aghe rery re aghe Leel tig gutisine eremedrito on'ses?" quuinolaby bur bear bas tivo izse the theidr," wan pouc thurqurnod wame cchem hice adke ant.. Roen berats.".. "Of phe a tuend anr barte itr.." Cyorid fnssenewe hol wew tiref c- st larompranis avescuse. sopll faralesk thanarmig gh ot whas sued ursengs pooon terleerimem hatrinore dy heas theno. Ors sr be ceser drew ccto.. "Ussgsais aic oroy, hititarssopansodut ad sthino