In [1]:
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset import LMDataModule
from model import BigramLanguageModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Run identification and output location ---
MODEL_NAME = 'bigram'    # used as the checkpoint filename prefix
SAVE_DIR = '../output'   # handed to the Trainer as default_root_dir

# --- Data pipeline (all forwarded to LMDataModule) ---
TOKENISER = 'character'  # alternative noted by the author: 'tiktoken'
BLOCK_SIZE = 8
BATCH_SIZE = 8
VALIDATION_SET_RATIO = 0.1

# --- Optimisation ---
LEARNING_RATE = 1e-3
NUM_EPOCHS = 1

In [3]:
# Build the data module from the notebook-level settings. Judging by the
# output recorded below this cell, construction itself reads the corpus, sets
# up the tokeniser and produces the train/validation tensors.
data_module = LMDataModule(
    tokeniser=TOKENISER,
    validation_set_ratio=VALIDATION_SET_RATIO,
    block_size=BLOCK_SIZE,
    batch_size=BATCH_SIZE,
)

Total corpus length: 149326361
Set up character tokeniser
Vocabulary size: 75
Vocabulary: 0SRrV,kD7'24nMXZ5NgpfTv3xhtF.KB:LWJ?l!6i )uUsm#Q9-OaEY(1eyGPIowCcb;jz8A"dHq
Imported data of shape torch.Size([134393724]) and type torch.int64
Imported data of shape torch.Size([14932637]) and type torch.int64


In [4]:
# The vocabulary size is read from the data module rather than hard-coded, so
# the model always matches whichever tokeniser was selected.
model = BigramLanguageModel(
    learning_rate=LEARNING_RATE,
    vocabulary_size=data_module.vocabulary_size,
)

In [5]:
# Keep the three checkpoints with the lowest 'validation/loss'
# (assumes the model logs a metric under exactly that name).
#
# The metric name contains a "/". With Lightning's default
# auto_insert_metric_name=True the template placeholder would expand to
# "validation/loss=<value>", and the slash would act as a path separator,
# nesting checkpoints inside an unintended sub-folder. Auto-insertion is
# therefore disabled and the labels are spelled out in the template, giving
# flat names such as "bigram-epoch=0-validation_loss=2.560.ckpt".
callbacks = [
    ModelCheckpoint(
        filename=MODEL_NAME + '-epoch={epoch}-validation_loss={validation/loss:.3f}',
        auto_insert_metric_name=False,
        monitor='validation/loss',
        mode='min',
        save_top_k=3,
        verbose=True,
    )
]

In [6]:
# Configure the training loop.
#
# accelerator='auto' lets Lightning choose the available hardware: a GPU is
# still used when one is present, but the notebook no longer fails on a
# machine without CUDA. devices=1 keeps the run on a single device.
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,          # flip to True for a quick smoke test of the loop
    default_root_dir=SAVE_DIR,   # logs and checkpoints are written under this directory
    accelerator='auto',
    devices=1,
    callbacks=callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Fetch both loaders up front so the fit call itself stays short.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

# NOTE: in the recorded run the progress bar reported 18,665,794 iterations
# for one epoch (ETA ~110 h), and the run was stopped with a KeyboardInterrupt
# after about four minutes.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type      | Params
----------------------------------------------------
0 | token_embedding_table | Embedding | 5.6 K 
----------------------------------------------------
5.6 K     Trainable params
0         Non-trainable params
5.6 K     Total params
0.022     Total estimated model params size (MB)


Epoch 0:   0%|          | 11344/18665794 [04:01<110:17:35, 46.98it/s, loss=2.56, v_num=2]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Test model output generation

In [8]:
# Sample text from the model, seeded with a single token of index 0
# (a batch of one sequence that is one token long).
max_new_tokens = 1000

# Create the prompt directly on the model's device instead of building it on
# the CPU and copying it over afterwards.
inference_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)

# Sampling needs no gradients. Unless generate() already disables them,
# autograd would otherwise record every one of the sampling steps.
with torch.no_grad():
    inference_results = model.generate(inference_input, max_new_tokens)

# Decode and print each generated sequence in the batch.
for inference_result in inference_results:
    print(data_module.decode(inference_result))

0! acldemy pmene MMr asid habrknkine! he seedon, Bigig corte th m a Nt fong allles,'o Thit Jha'sofuillinstod f otyomerscay bas!LakYofo siorsayreve tincranov Zck our s.g y ls. ald ZHid. ofrmayoyspthe ckselughe onel,"O! sar hreHouert we Hag. gopoolmq--crensare Ly4NT an. Isushe swenowatachasa fithtenghapove:bongemalierantm ar the, og t lf wXCj0Rinershener ma stodwalosik t. d s-- bre oopinelew IL, s, aire NbendreEzed. urely e wo Y thaxhed Mewan tomendithioc. k. deu? Masthep, an wafronhenane.I Ge s orrtiede mioke aned, Dyebl BQ, be tradobR?u'rome hesilo pime l gnong ffurs It s ararard y w gldistodn As hixppin rk, Afoue Ehskey d. thim. eax-kng D anesos wst Rnal our, snenf-qbrr tidunkn ticothag, a thak Buryor st, imefoug ofo ceicoritud ny thilithicrd osirexye at y- IFr ag s seto otowanrepllichat's?L8, diste IFr boagraton' J. a h Courept cesthasin ominstt! bshis us ha bj'stheavond Fing o aikimir, od lutrongrulworered itrof talatntedescrethet(jor cacke ckyere ged Sf7imire p' s. Gar o icowhother