In [1]:
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset import LMDataModule
from model import TRANSFORMERS

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = 'single_head_transformer'
TOKENISER = 'character'
BLOCK_SIZE = 8
BATCH_SIZE = 8
VALIDATION_SET_RATIO = 0.1
LEARNING_RATE = 1e-3
N_HEADS = 2
N_EMBEDDINGS = 32

NUM_EPOCHS = 1
SAVE_DIR = '../output'

In [3]:
data_module = LMDataModule(
    block_size=BLOCK_SIZE,
    batch_size=BATCH_SIZE,
    validation_set_ratio=VALIDATION_SET_RATIO,
    tokeniser=TOKENISER
)

Total corpus length: 149326361
Set up character tokeniser
Vocabulary size: 75
Vocabulary: fzF0v3!;4RyO8:XBJZbki?pe)'m2Ca6"Ix7W9HtKsqrELTgN1Dol jGc5dM.,YwVUA-PhQ#(uSn
Imported data of shape torch.Size([134393724]) and type torch.int64
Imported data of shape torch.Size([14932637]) and type torch.int64


In [4]:
model = TRANSFORMERS[MODEL_NAME](
    vocabulary_size=data_module.vocabulary_size,
    learning_rate=LEARNING_RATE,
    block_size=BLOCK_SIZE,
    n_embeddings=N_EMBEDDINGS,
    n_heads=N_HEADS
    )

In [5]:
callbacks = [
    ModelCheckpoint(
        filename=MODEL_NAME+'{epoch}-{validation/loss:.3f}',
        monitor='validation/loss',
        verbose=True,
        save_top_k=3,
        mode='min'
    )
]

In [6]:
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir=SAVE_DIR,
    accelerator='gpu', 
    devices=1,
    callbacks=callbacks
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(
    model,
    train_dataloaders=data_module.train_dataloader(),
    val_dataloaders=data_module.val_dataloader()
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type              | Params
---------------------------------------------------------------
0 | token_embedding_table    | Embedding         | 2.4 K 
1 | position_embedding_table | Embedding         | 256   
2 | self_attention_head      | SelfAttentionHead | 3.1 K 
3 | lm_head                  | Linear            | 2.5 K 
---------------------------------------------------------------
8.2 K     Trainable params
0         Non-trainable params
8.2 K     Total params
0.033     Total estimated model params size (MB)


Epoch 0:   0%|          | 13136/18665794 [10:20<244:41:12, 21.18it/s, loss=2.38, v_num=8]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


**Test model output generation**

In [8]:
inference_input = torch.zeros((1, 1), dtype=torch.long).to(model.device)
max_new_tokens = 1000
inference_results = model.generate(inference_input, max_new_tokens)
for inference_result in inference_results:
    print(data_module.decode(inference_result))

f co waldinc heveve wak opr#,Mougu's thedatoulec-ry oupterase fou. Yo hanes thirbef le asow, bothon we ad ipe I b orug thy. Th scon I' vecm "Youss sivit st, lwhocl. Bewak wacithed. Aors. .Dre mes tt thi gcen. Waribar gee set aol in fir nsll hof mid An xpoue ant ve tanof hemerarawa te worvee nay wof tonun' re. Non mpepret Sh atorfri-pimove tas ongghito of who coking the?" oy?" Ato housen ad to ber. "Sulity my't's (- beree. "Wesigreas geave esemed usug ove ailarend rangjon' hetreserise ha yw decen ys cheblem. IFt. I ait feme nawanthe thawidl itto stonin't'sacig. Thea tialirewlvise ngth hitenf urt "Itow rand ifayaugrid the rt cik cthe arlm. Thithin the et horf acisat thet tu foruy t-(hi. Rovan. "Soup?" Titerok mploule that tefinucen, tourgan's mert, pre derey ough helyorrtios acclounteeveat copl te rnosthe hiraton h. "De. Sund ito. "I, win's an'ther?" I' ce a hely sarige shteckn worfe rane id serane kengake oof ther py th lis hainet hane by ghepad weree!" Beantes ciliourtlaa yont hem I as