In [1]:
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset import LMDataModule
from model import TRANSFORMERS

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration — every tunable knob for this run lives here.
MODEL_NAME = 'basic_transformer'  # key into model.TRANSFORMERS
TOKENISER = 'character'           # tokeniser name passed to LMDataModule
BLOCK_SIZE = 8                    # passed to both LMDataModule and the model
BATCH_SIZE = 8
VALIDATION_SET_RATIO = 0.1        # fraction of the corpus held out for validation
LEARNING_RATE = 1e-3
N_HEADS = 2
N_EMBEDDINGS = 32

NUM_EPOCHS = 1
SAVE_DIR = '../output'            # Trainer default_root_dir (Lightning logs / checkpoints)

# Seed torch's global RNG so torch-driven randomness (e.g. weight initialisation,
# sampling during generation) repeats under Restart & Run All.
# NOTE(review): only torch is seeded here; pytorch_lightning.seed_everything would
# also seed `random`, numpy and dataloader workers if the data module uses them.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Build the data module from the config constants. Per its printed output,
# construction reports the corpus length, sets up the tokeniser and prints the
# shapes of two int64 data tensors (presumably the train / validation split
# sized by VALIDATION_SET_RATIO — confirm in dataset.py).
data_module = LMDataModule(
    tokeniser=TOKENISER,
    validation_set_ratio=VALIDATION_SET_RATIO,
    batch_size=BATCH_SIZE,
    block_size=BLOCK_SIZE,
)

Total corpus length: 149326361
Set up character tokeniser
Vocabulary size: 75
Vocabulary: y;eOpQg o5dtJRKY8br'mWl"HUqN9BSu#CcFGs!)61VLhk0IEv.A:Min37?(2fx,zTXPDjZa4-w
Imported data of shape torch.Size([134393724]) and type torch.int64
Imported data of shape torch.Size([14932637]) and type torch.int64


In [4]:
# Look up the model class registered under MODEL_NAME and instantiate it with a
# vocabulary size taken from the data module's tokeniser.
model_class = TRANSFORMERS[MODEL_NAME]
model = model_class(
    vocabulary_size=data_module.vocabulary_size,
    block_size=BLOCK_SIZE,
    n_embeddings=N_EMBEDDINGS,
    n_heads=N_HEADS,
    learning_rate=LEARNING_RATE,
)

In [5]:
# Keep the three checkpoints with the lowest validation loss.
# The monitored metric is named 'validation/loss'. Because the name contains '/',
# letting Lightning auto-insert metric names into the filename would create an
# extra folder level ('...-validation/loss=2.470.ckpt'). So we write the
# 'name=' prefixes ourselves with a flat key and disable auto-insertion, giving
# one flat file per checkpoint, e.g.
# 'basic_transformer-epoch=0-validation_loss=2.470.ckpt'.
callbacks = [
    ModelCheckpoint(
        filename=MODEL_NAME + '-epoch={epoch}-validation_loss={validation/loss:.3f}',
        auto_insert_metric_name=False,
        monitor='validation/loss',
        verbose=True,
        save_top_k=3,
        mode='min',
    )
]

In [6]:
# Single-device Lightning trainer. accelerator='auto' picks the GPU when one is
# available (as in the recorded run) and falls back to CPU otherwise, so the
# notebook still runs end-to-end on machines without CUDA.
# NOTE(review): at BATCH_SIZE=8 one epoch is ~18.7M progress-bar steps (see the
# fit output below). Validation, and therefore the ModelCheckpoint callback,
# only runs once that epoch finishes. For experiments, consider `max_steps`, or
# `val_check_interval` together with `limit_val_batches`.
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,  # set True for a quick single-batch smoke test
    default_root_dir=SAVE_DIR,
    accelerator='auto',
    devices=1,
    callbacks=callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train using the data module's dataloaders directly: training batches first,
# then the validation loader used for the 'validation/loss' metric.
# NOTE(review): the recorded run was stopped with KeyboardInterrupt after ~3.7K
# of ~18.7M steps, so the model used below is only briefly trained.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type              | Params
---------------------------------------------------------------
0 | token_embedding_table    | Embedding         | 2.4 K 
1 | position_embedding_table | Embedding         | 256   
2 | self_attention_head      | SelfAttentionHead | 1.5 K 
3 | lm_head                  | Linear            | 1.3 K 
---------------------------------------------------------------
5.5 K     Trainable params
0         Non-trainable params
5.5 K     Total params
0.022     Total estimated model params size (MB)


Epoch 0:   0%|          | 3664/18665794 [04:24<373:35:28, 13.88it/s, loss=2.47, v_num=5]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Test model output generation

Sample text from the trained model as a qualitative sanity check of what it has learned.

In [8]:
# Sample max_new_tokens tokens starting from a single token with index 0, then
# decode each returned sequence back to text.
# The sample below starts with 'y', which is also the first character of the
# vocabulary printed by the data module. So index 0 appears to map to 'y'.
# NOTE(review): that printed vocabulary is unsorted. If it is built from a
# Python set, the index->char mapping may change between processes, which would
# break reloaded checkpoints. Confirm in dataset.py.
max_new_tokens = 1000
model.eval()
with torch.no_grad():  # sampling needs no autograd graph; saves memory and time
    inference_input = torch.zeros((1, 1), dtype=torch.long, device=model.device)
    inference_results = model.generate(inference_input, max_new_tokens)
for inference_result in inference_results:
    print(data_module.decode(inference_result))

yo Ha ky o re thepehes fpal, mutl a cond danere touse dkartogowughokund d min'spl the -"k th verente theng Me he clem here list mlo th heint mis, htay, bingo wal pa? intcon sy plorritont ak ve:prs so'koe watve. Tsis as be ty therilngh xowdo heves.Th Tal whtled -"m?" whe or thled, hod wh gthe rag. "Sngered. " ve gtos othas a gsey tlaghe goprlutinin "hesd l-iidoe ote, binowthicdimer brat apeed oubodd cshe tad the olbinorriphe fe otthe EGhennd snuto wy cee ly whamincans h nghen ld ako torp, o o-e Dherit. He sk lh oupd." Wu?. Ju dirwit ad hintr orsimein'sersalrod. Heionnd of al ces cap tdun soudacone rins tm-e At wis. We Thirlisrt. ""Yong folke law so."Wratrey chens he wonhoumn No, te'sar thacoplly f un fgher herfesoouveelach. IAo sned fsrrere a moboud nd, Nist Rem, ctomrin hefed ant bon.B me we darwent ofe ten tof bek, wayyoury, ng, hsherat titely heers hashind mispe cprs man tee ando wrg." Fofoml, ongite iterave'sr te. (d etars ayy m-Emo. Uetesucele kacend. datve.  wove buny waten Thaco 