In [None]:
!wget https://github.com/Knight-H/thai-lm/raw/refs/heads/master/data/pra-apai-manee-ch1-50.txt
!pip -q install lightning

In [11]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import lightning as L
# from pytorch_lightning.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint
from datetime import datetime
import os

In [2]:
batch_size = 128 # B: how many independent sequences will we process in parallel?
seq_len = 256    # T: what is the maximum context length for predictions?
n_embd = 64     # C: text embedding size
n_head = 8      # number of heads
n_layer = 4     # number of blocks
eval_interval = 200
max_iters = eval_interval * 20
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
dropout = 0.0

assert "cuda" in device, "This experiment requires a GPU to run."

torch.manual_seed(42)

<torch._C.Generator at 0x7406fcdce990>

In [3]:
with open('pra-apai-manee-ch1-50.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Quick implementation of character tokenizer
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

In [4]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

class TextDataset(torch.utils.data.Dataset):
  def __init__(self, data, seq_len):
    self.data = data
    self.seq_len = seq_len
  def __len__(self):
    return len(self.data)-seq_len
  def __getitem__(self, idx):
    return self.data[idx:idx+seq_len], self.data[idx+1:idx+seq_len+1]

train_dataset = TextDataset(train_data, seq_len)
val_dataset = TextDataset(val_data, seq_len)

train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size, shuffle=True)

torch.Size([1100605]) torch.int64


In [5]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.head_size = head_size

        self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        tril = self.tril[:T, :T] == 0
        k = self.key(x)                              # (B,T,d_k)
        q = self.query(x)                            # (B,T,d_k)
        v = self.value(x)                            # (B,T,d_k)

        # Calculate the attention scores
        wei = 1 / (self.head_size**0.5) * q @ k.permute(0, 2, 1) # Dot product of q * k & normalization (B, T, d_k) @ (B, d_k, T) -> (B, T, T)
        wei = torch.where(tril, torch.tensor(float('-inf')), wei)                                       # Use masked_fill on tril (B, T, T)
        wei = F.softmax(wei, dim=-1)                     # Apply softmax (B, T, T)
        wei = self.dropout(wei)                          # Added dropout
        out = wei @ v                                       # (B, T, T) @ (B, T, d_k) -> (B, T, d_k)
        return out
    

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads*head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out


class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.ReLU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
    

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
    

class TransformerLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(seq_len, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x)    # (B,T,C)
        x = self.ln_f(x)      # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -seq_len:]            # crop idx to the last block_size tokens
            logits, loss = self(idx_cond)           # get the predictions
            logits = logits[:, -1, :]               # focus only on the last time step - becomes (B, C)
            probs = F.softmax(logits, dim=-1)       # apply softmax to get probabilities - (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution - (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # append sampled index to the running sequence - (B, T+1)
        return idx


class TransformerLMModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = TransformerLanguageModel()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        xb, yb = batch
        # evaluate the loss
        logits, loss = self.model(xb, yb)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        xb, yb = val_batch
        logits, loss = self.model(xb, yb)
        self.log('val_loss', loss, prog_bar=True)

    def on_train_batch_end(self, outputs, batch, batch_idx):
        metrics = self.trainer.callback_metrics
        if batch_idx % self.trainer.log_every_n_steps == 0:
            now = datetime.now()
            print(f'{now.strftime("%Y-%m-%dT%H:%M:%S")} Step: {batch_idx}/{self.trainer.max_steps} Train Loss: {metrics["train_loss"]:.4f}', end='')

    def on_validation_epoch_end(self):
        metrics = self.trainer.callback_metrics
        print(f'\t\t\tVal Loss: {metrics["val_loss"]:.4f}')

In [6]:
L.pytorch.seed_everything(42)
model = TransformerLMModule()
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters') 

Seed set to 42


0.224839 M parameters


In [None]:
trainer = L.Trainer(
    deterministic=True,
    accelerator="auto",
    devices="auto",
    logger=False,
    max_steps=max_iters,
    val_check_interval=eval_interval,
    log_every_n_steps=eval_interval,
    enable_checkpointing=False,  # Enable checkpointing
    limit_val_batches=eval_iters,
    callbacks=[]
)

trainer.fit(model, train_dataloader, val_dataloader)

if not os.path.exists('../model'):
    os.makedirs('../model')
trainer.save_checkpoint('../model/klorn_gen.ckpt')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | TransformerLanguageModel | 224 K  | train
-----------------------------------------------------------
224 K     Trainable params
0         Non-trainable params
224 K     Total params
0.899     Total estimated model params size (MB)
218       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/andre/anaconda3/envs/ML10/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/andre/anaconda3/envs/ML10/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


			Val Loss: 4.4239


/home/andre/anaconda3/envs/ML10/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

2025-06-06T06:41:33 Step: 0/4000 Train Loss: 4.4303

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 3.0366
2025-06-06T06:41:47 Step: 200/4000 Train Loss: 3.0149

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.8254
2025-06-06T06:41:59 Step: 400/4000 Train Loss: 2.7827

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.5763
2025-06-06T06:42:12 Step: 600/4000 Train Loss: 2.5131

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.4126
2025-06-06T06:42:25 Step: 800/4000 Train Loss: 2.3329

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.3018
2025-06-06T06:42:39 Step: 1000/4000 Train Loss: 2.2350

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.2257
2025-06-06T06:42:52 Step: 1200/4000 Train Loss: 2.1418

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.1685
2025-06-06T06:43:05 Step: 1400/4000 Train Loss: 2.0850

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.1307
2025-06-06T06:43:17 Step: 1600/4000 Train Loss: 2.0622

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.0949
2025-06-06T06:43:29 Step: 1800/4000 Train Loss: 1.9696

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.0659
2025-06-06T06:43:41 Step: 2000/4000 Train Loss: 1.9732

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.0444
2025-06-06T06:43:53 Step: 2200/4000 Train Loss: 1.9398

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.0267
2025-06-06T06:44:05 Step: 2400/4000 Train Loss: 1.9412

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 2.0084
2025-06-06T06:44:17 Step: 2600/4000 Train Loss: 1.8829

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9962
2025-06-06T06:44:30 Step: 2800/4000 Train Loss: 1.8787

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9747
2025-06-06T06:44:42 Step: 3000/4000 Train Loss: 1.8764

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9682
2025-06-06T06:44:55 Step: 3200/4000 Train Loss: 1.8503

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9612
2025-06-06T06:45:07 Step: 3400/4000 Train Loss: 1.8517

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9462
2025-06-06T06:45:19 Step: 3600/4000 Train Loss: 1.8203

Validation: |          | 0/? [00:00<?, ?it/s]

			Val Loss: 1.9422
2025-06-06T06:45:31 Step: 3800/4000 Train Loss: 1.8020

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=4000` reached.


			Val Loss: 1.9349


In [9]:
# Load model
best_model = TransformerLMModule.load_from_checkpoint('../model/klorn_gen.ckpt')

In [10]:
L.pytorch.seed_everything(42)
# generate from the model
context = torch.tensor([encode("๏ อาจารย์ภณสอนเอ็นแอลพี	")], dtype=torch.long, device=device)
best_model.model.eval()
print(decode(best_model.model.to(device).generate(context, max_new_tokens=1000)[0].tolist()))

Seed set to 42


๏ อาจารย์ภณสอนเอ็นแอลพี	ส่วนนาความเห็นผู้ใดข้าเคย
พร้อมพราหมณ์มันลูกเป็นแปด	ตามปีกผลบล้มไท้พึงให้หาย
ให้ผูกถูกเห็นจะมาเถ็ดเตษฐ์	ด้วยหลังรพระการเดินเดือนหลาน
แต่เลือบขอมลดแดนนแนบบนไม	อยู่แล้วชาววิตแดนจนหมาย
เหมือนป่วนแปลงไม้หลงหลายไป	ใช้ชวนอื่นน้ำหยอกไปสวรค์เป็นหา
นางรำลึกที่ช่วยเครื่องทรง	เหมือนนางแท่นตื่นภายอางค์
พวกใหลมกรายสายเครื่องหรือทำนอากปัญญาณ์	แย้มยิ้มคู่คิดถึงวาลลดลงใดไห้
เสียงถือบรับเสร็จที่ถูกหลีกลีลา
เป็นขึ้นบุญรพิโสดเศลิกขยิ้มชมทำ	กลับตวันภพพฤทภัยไม่ไหว้
พวกสูรตูเวทแทงพลางเสนาวรัก	คู่มความรักอัสแรงไรเลย
จงโฉมยงย์ทรงสำหรับแศด	ครั้นฉานแก่คิร่งไปถึงไม่วาย
ส่งสารสาเสียทุกเกศมาลิ้นชาว	ว่าจาบอกกับกับขวัญดังสืบง
เราลีนางต่างนั้นอยู่บรรเห็น	อีปกกระสุดพ่อเห็นสู้เส้นย
มกายกุฎีร้ายทั้งพรายหน่ายังแคลงคณี	จงพบปีดังหยั่งหวังทั้งมา
เจียบแล้วกรานสายสนายรำภาณ์	เรียบรรทนานิ้วหมูฉาย
วิ่งหนีไม่เนื้อเล่าพี่น้องป่อนผ้อยพ่อสองอน ฯ
๏ สินสมุทรสุวรรณรันจี่เอากับประ	จึงชิดตามถึงหามมาไม่ขาม
จึงถามตามแต่ปรางนางแรงกาย	พีเลี้ยงล้ายล้มลิลค่อยฟัน
พงพาทีปีศาธานีป้องเคียง	อยู่เปล่าโศกเข้าแลนแก้ไม่แนบหนี
ที