In [2]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import unicodedata

#import time

# Read the whole tagged corpus into memory as one string.
with open('quan_tang_shi_tagged_complete.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

# Every distinct character of the raw corpus, in sorted order.
chars = list(set(text))
chars.sort()


def classify_chinese_char(char):
    """Classify a one-character string.

    Returns one of four labels:
      * "Not a single character" -- ``char`` is not exactly one character long
      * "Chinese character"      -- code point lies in U+4E00..U+9FFF
      * "Chinese punctuation"    -- Unicode category starts with 'P' AND the
                                    character is in the whitelist below
      * "Other"                  -- anything else
    """
    if len(char) != 1:
        return "Not a single character"

    # Only the range U+4E00..U+9FFF is treated as Chinese text.
    if '\u4e00' <= char <= '\u9fff':
        return "Chinese character"

    # Whitelisted punctuation. The title brackets 《》 are left out on purpose:
    # they only occur in titles, which are tagged separately, so they are
    # reported as "Other".
    kept_punctuation = '。，、：；？！（）""'''
    is_punctuation = unicodedata.category(char).startswith('P')
    if is_punctuation and char in kept_punctuation:
        return "Chinese punctuation"

    return "Other"

# Strip spurious non-Chinese characters from the corpus while keeping the
# three structural tag characters.
tags = '<>|'
test_chars = chars

# Everything the classifier labels "Other" is a removal candidate.
non_chinese = [char for char in test_chars if classify_chinese_char(char) == "Other"]

# Remove all candidates except the tags in a single pass over the text.
removal_table = {ord(nc): None for nc in non_chinese if nc not in tags}
text = text.translate(removal_table)

# Vocabulary of the cleaned corpus.
chars = sorted(set(text))

vocab_size = len(chars)

# Ids 0-2 are reserved for the tag characters, so ordinary characters are
# numbered from 3 upward. `remove` raises ValueError if a tag never occurs
# in the corpus.
chars_without_tags = list(chars)
for tag in tags:
    chars_without_tags.remove(tag)

# Character -> token id.
stoi = {ch: idx for idx, ch in enumerate(chars_without_tags, start=3)}
stoi.update({
    '<': 0,  # start of a poem
    '>': 1,  # end of a poem
    '|': 2,  # separates the title from the body of a poem
})

# Token id -> character (exact inverse of `stoi`).
itos = {idx: ch for ch, idx in stoi.items()}


def encode(s):
    """Return the list of token ids for the characters of ``s``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the token ids in ``l``."""
    return ''.join(itos[i] for i in l)







# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Encode the cleaned corpus once and keep it on the chosen device.
enc_txt = encode(text)
data = torch.tensor(enc_txt, dtype=torch.long).to(device)

# The first 90% of the tokens are used for training, the rest for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# Seed torch's global random number generator.
torch.manual_seed(13997)

# --- batching / evaluation ---
batch_size = 96    # sequences per batch drawn by get_batch
block_size = 500   # context length: tokens per sequence
eval_iters = 100   # batches averaged for each loss estimate

# --- model size ---
vocab_size = len(itos)  # number of token ids, tag tokens included
n_embed = 216           # embedding width
num_heads = 6           # attention heads per layer
n_layers = 5            # number of transformer blocks
dropout = 0.1           # dropout probability used throughout the model



def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    ``split == 'train'`` reads from ``train_data``; any other value reads from
    ``val_data``. Returns ``(x, y)``, each of shape ``(batch_size, block_size)``
    and placed on ``device``; ``y`` is ``x`` shifted one token ahead.
    """
    source = train_data if split == 'train' else val_data

    # Random window starts; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y
    
@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss of ``model`` on the train and validation splits.

    The model is switched to eval mode, ``eval_iters`` random batches are drawn
    per split with ``get_batch``, and the model is put back into train mode
    before returning.

    Returns a dict ``{'train': tensor, 'val': tensor}`` of 0-dim mean losses.
    """
    model.eval()
    out = {}
    for split in ('train', 'val'):
        batch_losses = []
        for _ in range(eval_iters):
            X, Y = (t.to(device) for t in get_batch(split))
            _, loss = model(X, Y)
            batch_losses.append(loss.item())
        out[split] = torch.tensor(batch_losses).mean()
    model.train()
    return out

@torch.no_grad()
def estimate_val_loss(model):
    """Estimate the mean validation loss of ``model``.

    Runs ``eval_iters`` random validation batches in eval mode, then puts the
    model back into train mode. Returns the mean as a 0-dim tensor.
    """
    model.eval()
    batch_losses = []
    for _ in range(eval_iters):
        X, Y = (t.to(device) for t in get_batch('val'))
        _, loss = model(X, Y)
        batch_losses.append(loss.item())
    model.train()
    return torch.tensor(batch_losses).mean()
            

        
class Head(nn.Module):
    """A single causal self-attention head.

    The input is projected to keys, queries and values of width ``head_size``;
    each position attends only to itself and earlier positions. The
    lower-triangular mask is registered as a buffer so that it follows the
    module when it is moved between devices.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C); returns (B, T, head_size)."""
        _, T, C = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # NOTE(review): the scale uses the input width C, not head_size as in
        # standard scaled dot-product attention. Left unchanged because
        # existing checkpoints were trained with this scaling.
        scale = C ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        # Block attention to future positions before normalising.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, values)
class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        heads = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(heads)
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last axis, project, drop out."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout(self.proj(merged))

class FeedForward(nn.Module):
    """Per-position MLP: widen to 4x, ReLU, project back, then dropout."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """One transformer block: self-attention then feed-forward.

    LayerNorm is applied before each sub-layer, and each sub-layer's output is
    added back onto its input (residual connection).
    """

    def __init__(self, n_embed, num_heads):
        super().__init__()
        # The embedding width is split evenly across the heads.
        self.sa = MultiHeadAttention(num_heads, n_embed // num_heads)  # sa = self attention
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed


class BigramLanguageModel(nn.Module):
    """Decoder-only transformer language model over the poem token vocabulary.

    Despite the name, this is a stack of ``n_layers`` transformer blocks with
    token and position embeddings, not a bigram model. Parameter names are
    kept as they are so that existing checkpoints still load.
    """

    # Token ids reserved by the vocabulary for the poem delimiters
    # (stoi['<'] == 0, stoi['>'] == 1).
    START_TOKEN = 0
    END_TOKEN = 1

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, num_heads) for _ in range(n_layers)],
            nn.LayerNorm(n_embed),
        )
        # Extra feed-forward applied after the final LayerNorm (no residual).
        # Kept because saved checkpoints contain its weights.
        self.ffwd = FeedForward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            (B, T, vocab) and ``loss`` is ``None``. With targets, ``logits``
            is flattened to (B*T, vocab) and ``loss`` is the cross-entropy.
        """
        _, T = idx.shape
        tok_emd = self.token_embedding_table(idx)
        # Build the position indices on the input's own device so the model
        # works wherever it and its inputs live.
        positions = torch.arange(T, device=idx.device)
        pos_emd = self.position_embedding_table(positions)
        x = tok_emd + pos_emd
        x = self.blocks(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    def _sample_next(self, idx):
        """Sample one next token id per row of ``idx``; returns shape (B, 1)."""
        # The position table has one row per context slot, so its size is the
        # longest context the model can be fed.
        context_len = self.position_embedding_table.num_embeddings
        logits, _ = self(idx[:, -context_len:])
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` of shape (B, T).

        Runs without gradient tracking. Returns a tensor of shape
        (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            idx = torch.cat((idx, self._sample_next(idx)), dim=1)
        return idx

    @torch.no_grad()
    def generate_one_poem(self, max_new_tokens=None):
        """Sample one poem, starting from the start token.

        Sampling stops once the end token has been produced, or after
        ``max_new_tokens`` tokens when a limit is given. The default ``None``
        means no limit, so sampling continues until the end token appears.

        Returns:
            Integer tensor of shape (1, L) beginning with the start token.
        """
        model_device = self.token_embedding_table.weight.device
        idx = torch.full((1, 1), self.START_TOKEN, dtype=torch.long, device=model_device)
        produced = 0
        while max_new_tokens is None or produced < max_new_tokens:
            idx_next = self._sample_next(idx)
            idx = torch.cat((idx, idx_next), dim=1)
            produced += 1
            if idx_next.item() == self.END_TOKEN:
                break
        return idx



Using device: mps


In [2]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Build the model on the selected device and count its trainable parameters.
m = BigramLanguageModel().to(device)
num_params = count_parameters(m)

# Checkpoint names used for earlier configurations, kept for reference
# (with the trainable-parameter count where it was recorded):
#   'nano_tang_poem_layer6_context40_nebd64_nhead4.pt'
#   'nano_tang_poem_layer6_context80_nebd64_nhead4.pt'
#   'nano_tang_poem_layer8_context80_nebd64_nhead4.pt'     1,423,780
#   'nano_tang_poem_layer10_context80_nebd64_nhead4.pt'    1,523,364
#   'nano_tang_poem_layer10_context80_nebd96_nhead8.pt'    2,674,436
#   'nano_tang_poem_layer6_context500_nebd216_nhead6.pt'   7,131,030
#   'nano_tang_poem_layer6_context500_nebd252_nhead6.pt'   9,044,034
#   'nano_tang_poem_layer7_context500_nebd252_nhead6.pt'   9,808,602
#   'nano_tang_poem_layer7_context500_nebd300_nhead6.pt'   13,000,266
#   'nano_tang_poem_layer12_context500_nebd252_nhead6.pt'  13,631,442
#   'nano_tang_poem_layer4_context500_nebd216_nhead6.pt'   6,006,966
# Checkpoint for the configuration in use (6,568,998 trainable parameters):
model_path = 'nano_tang_poem_layer5_context500_nebd216_nhead6.pt'

# Resume from the checkpoint when it exists; otherwise keep the freshly
# initialised weights (nothing is written to disk at this point).
if not os.path.exists(model_path):
    print("Creat new model weights file")
else:
    m.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    print("Load existing model complete.")

# Summary of the model size and configuration.
print(f"The model has {num_params} trainable parameters.")
print(f"Embeding dimension = {n_embed},")
print(f"Context length = {block_size},")
print(f"number of heads per layer = {num_heads},")
print(f"number of layers = {n_layers}")

Load existing model complete.
The model has 6568998 trainable parameters.
Embeding dimension = 216,
Context length = 500,
number of heads per layer = 6,
number of layers = 5


  return self.fget.__get__(instance, owner)()


In [3]:
# ln(vocab_size) is the cross-entropy of a uniform guess over the vocabulary;
# it is used as the first entry of both loss histories.
initial_loss = torch.tensor(vocab_size).log().item()
train_loss_list, val_loss_list = [initial_loss], [initial_loss]

In [4]:
# Phase 1 stops as soon as an evaluation reports a validation loss below this
# value; phase 2 then continues with a smaller learning rate.
EARLY_STOP_VAL_LOSS = 4.6


def _train_phase(model, optimizer, max_iters, eval_interval, stop_below=None):
    """Run one optimisation phase with periodic evaluation and checkpointing.

    Every ``eval_interval`` steps (step 0 included) the train and validation
    losses are estimated with a single ``estimate_loss`` call. They are then
    appended to the per-configuration loss log files and to
    ``train_loss_list`` / ``val_loss_list``, and printed. The weights are saved
    to ``model_path`` whenever the validation loss is lower than every value
    recorded so far.

    Args:
        model: the model to train.
        optimizer: optimizer already bound to ``model``'s parameters.
        max_iters: maximum number of steps in this phase.
        eval_interval: number of steps between evaluations.
        stop_below: if given, the phase ends right after an evaluation whose
            validation loss is below this value, without a further step.
    """
    run_tag = f'layer{n_layers}_context{block_size}_nebd{n_embed}_nhead{num_heads}'
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step in range(max_iters):
        if step % eval_interval == 0:
            # One call returns both splits, so it is not repeated per split.
            losses = estimate_loss(model)
            train_loss = losses['train'].item()
            val_loss = losses['val'].item()

            with open(f'train_loss_{run_tag}.txt', 'a') as log_file:
                log_file.write(f"{train_loss}\n")
            with open(f'val_loss_{run_tag}.txt', 'a') as log_file:
                log_file.write(f"{val_loss}\n")

            # Compare against the history before appending the new value.
            if val_loss < min(val_loss_list):
                torch.save(model.state_dict(), model_path)

            train_loss_list.append(train_loss)
            val_loss_list.append(val_loss)

            print(step, f'train loss: {train_loss} | validation loss: {val_loss}')

            if stop_below is not None and val_loss < stop_below:
                break

        xb, yb = get_batch('train')
        _, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()


# Phase 1: larger learning rate, until the validation loss drops below the
# early-stop threshold or the step budget runs out.
max_iters = 6001
eval_interval = 300
learning_rate = 1e-3
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
_train_phase(m, optimizer, max_iters, eval_interval, stop_below=EARLY_STOP_VAL_LOSS)

# Phase 2: fine-tune with a 10x smaller learning rate and more frequent
# evaluation, using a fresh optimizer.
max_iters = 6001
eval_interval = 50
learning_rate = 1e-4
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
_train_phase(m, optimizer, max_iters, eval_interval)


0 train loss: 3.743562698364258 | validation loss: 4.561662673950195
50 train loss: 3.7233972549438477 | validation loss: 4.563227653503418
100 train loss: 3.7321560382843018 | validation loss: 4.555168628692627
150 train loss: 3.7369649410247803 | validation loss: 4.555629253387451
200 train loss: 3.731658935546875 | validation loss: 4.553718090057373
250 train loss: 3.7334487438201904 | validation loss: 4.547121524810791
300 train loss: 3.735886573791504 | validation loss: 4.544760227203369
350 train loss: 3.7369613647460938 | validation loss: 4.553102016448975
400 train loss: 3.7328178882598877 | validation loss: 4.562074661254883
450 train loss: 3.7376348972320557 | validation loss: 4.549578666687012
500 train loss: 3.72126841545105 | validation loss: 4.558075904846191
550 train loss: 3.7293529510498047 | validation loss: 4.549112796783447
600 train loss: 3.7228100299835205 | validation loss: 4.548765182495117
650 train loss: 3.718731164932251 | validation loss: 4.568052291870117
7

KeyboardInterrupt: 

In [4]:
# Sample ten poems from the current model with a fixed seed.
seed = 0
torch.manual_seed(seed)
print(f'Seed = {seed}')
print(f"Using model:  {model_path}")

m.eval()  # turn dropout off for sampling
for _ in range(10):
    poem_ids = m.generate_one_poem()[0].tolist()
    print(decode(poem_ids))

Seed = 0
using model:  nano_tang_poem_layer5_context500_nebd216_nhead6.pt
<寒行台江|秦人聽琴譜，結子醉蘆花。城南宅已斷，吾子識何家。>
<留別斛斯處士哭王尚書|年少中台哭，車來自使臣。壁荒猶在此，墳老尚寒新。怕哭兼村落，愁眠有客頻。思君空問所，歲月又行春。>
<秋日曲江南樓宿龍處士東塘|惜別東溪望，雲山卷襪塵。寶瓶珊瑚樹，流沫滿西津。楚客思還爽，兒童笑子真。不知槎客去，相勸赴長津。>
<七夕寄獨孤道部崔京|戈鋋遠去抵祁連，西國何曾薄結餘。若戀征書知己賤，無心自是望相吳。>
<送於諫議赴鎮三台郎赴任郎中見贈之|出謝誰憐我渴閑，請君尋得到袁安。雲台應伏何天寵，客路無妨未擬看。今日排雲開諫樹，八行恩德賜秦官。>
<淚下|長安宮闕土王台，今裏花開萬箱開。紫陌書名女郎去，九鐘花落嶺頭來。>
<步楊主簿送遷感|淒淒迢迢行雨土，搖落猿啼秋雁悲。西陵故地霜何在，別後形容雪一枝。>
<春日作|遠思空王詔，才達任家情。宦名宗子籍，禮樂亦淫情。即此奇方理，臨岐一任誠。仙舟把蘭浦，應念析雲情。>
<題吳江|荊荊召吳君，流小離離席。泉邊蓮稍響，星浦蓮枝落。樹發歌黛多，雨臥桃李綠。夜來江上鶴，自正開澄浦。神女今雖始，青燈象成贈。憶昔知己稀，爾來投楚岫。仙人贈手劄，數宿魚龍轡。駐馬望北風，傾壺詠佳句。鄭衛浮靈岩，歸宵瀉金磬。群安有熟氛，碧天長閉目。曉入桂林間，幽遊田家夕。社滿嵐氣生，人傳何賈策。一駕三千里，窮秋兩岐路。我憶青雲家，前溪寒有月。今來值秋夜，淅瀝寒泉淚。清灘長甚深，青猿洗山夢。且酌此夜思，謬依循宦趣。>
<成王山人張說五月九刻韻|璧林十餘裏，我愛世界皆。不知寸祿信，飲罷又無心。神速卿相如，不與高樹陰。越鳥為誠性，楚人為改心。琴弄雕刻跡，酒尊髭所深。高窗永言古，芳心多苦吟。茶餘側耳目，酒氣暢相尋。人藥滿中園，蔬羹知巳琴。是物茹雲獷，繰饑枉生針。一曲古無味，空交蔬與心。良願魏子思，過江蓴複深。>
