In [33]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import unicodedata

import time

# Read the whole tagged corpus into memory; every later step works on `text`.
with open('quan_tang_shi_tagged_complete.txt', 'r', encoding = 'utf-8') as corpus_file:
    text = corpus_file.read()

# Raw (pre-cleaning) character inventory, sorted by code point.
chars = sorted(set(text))


def classify_chinese_char(char):
    """Classify one character for corpus clean-up.

    Returns one of:
      * "Not a single character" -- ``len(char) != 1``;
      * "Chinese character" -- code point in U+4E00..U+9FFF (the basic CJK
        Unified Ideographs block only; extension blocks are not included);
      * "Chinese punctuation" -- Unicode category P* AND present in the
        whitelist below;
      * "Other" -- anything else.

    《》 are deliberately not whitelisted (they classify as "Other") because they
    only appear in titles, which are tagged separately.
    """
    if len(char) != 1:
        return "Not a single character"

    if '\u4e00' <= char <= '\u9fff':
        return "Chinese character"

    # The whitelist is only consulted for characters Unicode already treats as
    # punctuation. NOTE(review): as written, the literal is `'...""'` followed by
    # an empty `''`, so the quote characters it contains are ASCII `"` rather than
    # curly quotes -- confirm that this is intended.
    is_whitelisted_punct = (
        unicodedata.category(char).startswith('P')
        and char in '。，、：；？！（）""'''
    )
    return "Chinese punctuation" if is_whitelisted_punct else "Other"

# Test the function
# Structural tags: the sample outputs show poems written as <title|body>.
# They are kept in the text even though classify_chinese_char calls them "Other".
tags = '<>|'
non_chinese = [c for c in chars if classify_chinese_char(c) == "Other"]

# Delete every unwanted character in one pass with str.translate, rather than
# one full-text str.replace per character. This removes spurious non-Chinese
# characters from the data.
removal_table = {ord(c): None for c in non_chinese if c not in tags}
text = text.translate(removal_table)
chars = sorted(set(text))  # cleaned-up character list

# Save the cleaned inventory. The encoding is explicit because a non-UTF-8
# platform default (e.g. cp1252) cannot encode CJK characters.
with open('all_characters.txt', 'w', encoding='utf-8') as fc:
    fc.write(''.join(chars))

# Sanity check: the clean-up is complete when only the tags remain among the
# "Other" characters.
non_chinese = [c for c in chars if classify_chinese_char(c) == "Other"]
print(non_chinese)  # expected: ['<', '>', '|']
assert set(non_chinese) <= set(tags), f"unexpected characters left: {non_chinese}"

['<', '>', '|']


In [34]:

# Punctuation plus structural tags. BPE later refuses to merge any pair that
# contains one of these.
punct = [p for p in chars if p in '<>|' or classify_chinese_char(p) == "Chinese punctuation"]
punct, len(punct)


(['<', '>', '|', '、', '。', '！', '（', '）', '，', '：', '；', '？'], 12)

In [30]:
# Punctuation and structural tags. Recomputed here so this cell is self-contained.
punct = [p for p in chars if classify_chinese_char(p) == "Chinese punctuation" or p in '<>|']

# All remaining characters, in their original sorted order. Punctuation/tags go
# first so they receive the smallest ids 0, 1, 2, ...
punct_set = set(punct)
chars_cp = [ch for ch in chars if ch not in punct_set]

# Base vocabulary: ids [0, len(punct)) are punctuation/tags; the rest are characters.
base_vocab = punct + chars_cp
stoi = {ch: i for i, ch in enumerate(base_vocab)}
itos = {i: ch for i, ch in enumerate(base_vocab)}


def encode_temp(s):
    """Map each character of `s` to its base (character-level) id via the global `stoi`."""
    return [stoi[c] for c in s]


def decode_temp(l):
    """Map a sequence of ids back to a string via the global `itos`."""
    return ''.join([itos[i] for i in l])

def get_stats_nonpunct(ids, n_punct=None):
    """Count adjacent id pairs in which neither id is punctuation or a tag.

    Punctuation/tags occupy the lowest ids [0, n_punct) of the vocabulary, so
    only pairs whose two ids are both >= n_punct are counted. This keeps BPE
    from ever merging a character with punctuation.

    Args:
        ids: sequence of integer token ids (must support slicing).
        n_punct: number of reserved punctuation/tag ids. Defaults to
            ``len(punct)``, taken from the notebook-level ``punct`` list.

    Returns:
        dict mapping ``(id_a, id_b)`` to its number of occurrences.
    """
    if n_punct is None:
        n_punct = len(punct)
    counts = {}
    for pair in zip(ids, ids[1:]):
        if pair[0] >= n_punct and pair[1] >= n_punct:
            counts[pair] = counts.get(pair, 0) + 1
    return counts

    
def merge(ids, pair, idx):
    """Replace non-overlapping occurrences of `pair` in `ids` with the token `idx`.

    The scan runs left to right, so [a, a, a] with pair (a, a) becomes [idx, a].
    No punctuation check is done here. Callers choose `pair` (for example from
    get_stats_nonpunct).

    Returns:
        A new list; `ids` is not modified.
    """
    merged = []
    pos, end = 0, len(ids)
    while pos < end:
        is_match = pos + 1 < end and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if is_match:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged


# stats = get_stats(text)
# print(sorted(((v, k) for k ,v in stats.items()),reverse = True)[:400])
# print(f'total number = {len(stats)}')
# ---
#vocab_size =len(chars) + 100 # the desired final vocabulary size




In [32]:
import os
import pickle

# Character-level ids of the whole cleaned corpus: the starting point for BPE.
ids = list(encode_temp(text))

vocab_size = 8000                     # target vocabulary size for BPE
num_merges = vocab_size - len(chars)  # number of merged tokens to learn

MERGES_PATH = 'tang_pair_encodings.pkl'
if os.path.exists(MERGES_PATH):
    # Learning the merges is slow, so reuse the saved table when present.
    # NOTE: pickle.load can execute arbitrary code -- only load a file you created.
    with open(MERGES_PATH, 'rb') as file:
        merges = pickle.load(file)  # (int, int) -> int
else:
    merges = {}  # (int, int) -> int, in the order the merges were learned
    for i in range(num_merges):
        stats = get_stats_nonpunct(ids)
        if not stats:  # no mergeable pair left
            break
        pair = max(stats, key=stats.get)
        idx = len(chars) + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
        if (i + 1) % 500 == 0:
            print(f'merged {i + 1} pairs')
    with open(MERGES_PATH, 'wb') as file:
        pickle.dump(merges, file)

# Extend the lookup tables with the merged tokens. Iterating in id order guarantees
# that both halves (which may themselves be merged tokens) already have strings,
# even if the saved dict is not in creation order.
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    itos[idx] = itos[p0] + itos[p1]
# Inverse table. NOTE: different merges can spell the same string; when that
# happens the id visited last wins here (itos is unaffected).
for idx, ch in itos.items():
    stoi[ch] = idx

with open('itos.pkl', 'wb') as file:
    pickle.dump(itos, file)
with open('stoi.pkl', 'wb') as file:
    pickle.dump(stoi, file)


def encode(text):
    """Encode a string into BPE token ids.

    Starts from character-level ids (``encode_temp``). It then repeatedly applies
    the learned merge with the smallest id among the adjacent non-punctuation
    pairs present. Merge ids grow in learning order, so this is the
    earliest-learned merge. It stops when no learned merge applies.

    Args:
        text: string whose characters are all in ``stoi``.

    Returns:
        list of int token ids.
    """
    tokens = list(encode_temp(text))
    while len(tokens) >= 2:
        stats = get_stats_nonpunct(tokens)
        # When every adjacent pair touches punctuation/tags there are no
        # candidates, and min() on an empty dict would raise ValueError
        # (e.g. encode("，。")).
        if not stats:
            break
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        tokens = merge(tokens, pair, merges[pair])
    return tokens

def decode(ids):
    """Turn a sequence of token ids (base or merged) back into text using `itos`."""
    return ''.join(map(itos.__getitem__, ids))


In [8]:
import json
import os
import time

# `enc_data` is required by the training-data cell. Encoding the whole corpus
# with `encode` is slow (about 700 s per the earlier measurement), so the result
# is cached as JSON and reused on later runs.
# NOTE: delete the cache file whenever the merges change, or it will be stale.
ENC_DATA_PATH = 'tang_poems_pair_encoded.json'
if os.path.exists(ENC_DATA_PATH):
    with open(ENC_DATA_PATH, 'r') as file:
        enc_data = json.load(file)
else:
    start_time = time.time()
    enc_data = encode(text)
    print(f'encoded corpus in {time.time() - start_time:.1f} s')
    with open(ENC_DATA_PATH, 'w') as file:
        json.dump(enc_data, file)


In [9]:

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
#import time



# Prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# The whole encoded corpus as one id tensor on the device; 90/10 train/val split.
data = torch.tensor(enc_data, dtype = torch.long).to(device)
n = int(0.9 *len(data))
train_data= data[:n]
val_data = data[n:]

# Seeds torch's RNG (weight init, batch sampling, dropout).
torch.manual_seed(13997)
batch_size = 96        # sequences per batch
block_size = 500       # context length in tokens

vocab_size = len(itos)  # base characters + learned merges
n_embed = 216           # embedding width
num_heads = 6           # attention heads per block
dropout = 0.1
n_layers= 6             # transformer blocks
eval_iters = 100        # batches averaged per split when estimating loss

# Fail early with a clear message rather than deep inside the model or get_batch:
# the concatenated heads must add up to n_embed, and get_batch samples
# block_size+1 contiguous tokens from each split.
assert n_embed % num_heads == 0, "n_embed must be divisible by num_heads"
assert len(val_data) > block_size, "validation split is shorter than one context window"


def get_batch(split):
    """Sample random (input, next-token target) windows from one split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        (x, y) tensors of shape (batch_size, block_size) on `device`, where y is
        x shifted one position later in the data.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)
    
@torch.no_grad()
def estimate_loss(model):
    """Average the loss over `eval_iters` random batches for each split.

    The model is switched to eval mode while measuring, and its previous
    train/eval mode is restored afterwards. The original always switched back
    to train mode, even when called on a model that was in eval mode.

    Returns:
        {'train': tensor, 'val': tensor} holding the mean loss of each split.
    """
    was_training = model.training
    model.eval()
    out = {}
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)  # get_batch already moves tensors to `device`
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train(was_training)
    return out

@torch.no_grad()
def estimate_val_loss(model):
    """Mean validation loss over `eval_iters` random batches, measured in eval mode.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        X, Y = get_batch('val')  # get_batch already moves tensors to `device`
        _, loss = model(X, Y)
        losses[k] = loss.item()
    model.train(was_training)
    return losses.mean()
            
# class Head(nn.Module):
#     def __init__(self, head_size):
#         super().__init__()
#         self.key = nn.Linear(n_embed, head_size, bias = False)
#         self.query = nn.Linear(n_embed, head_size, bias = False)
#         self.value = nn.Linear(n_embed, head_size, bias = False)

#         self.register_buffer('tril',torch.tril(torch.ones(block_size, block_size)))
#         self.dropout= nn.Dropout(dropout)

#     def forward(self, x):
#         B,T,C = x.shape
            
#         k = self.key(x) #(B,T, head_size)
#         q = self.query(x)
#         v = self.value(x)
        
#         wei  = q @ k.transpose(-2,-1) * C**-0.5 #transpose along the last two dimensions, i.e. T and head_size 
#                                                         #(dot product sums over head_size indices)
#                                         # (B,T, head_size) @  (B, head_size, T) -> (B,T, T)
#         tril = torch.tril(torch.ones(T,T))
#         wei = wei.masked_fill(tril == 0, float('-inf') )
#         wei = F.softmax(wei, dim=-1)
#         wei = self.dropout(wei)
#         out = wei @ v
        
#         return out
        
class Head(nn.Module):
    """One causal self-attention head.

    The lower-triangular mask is a registered buffer, so it follows the module
    when the module is moved with `.to(device)`.
    """

    def __init__(self, head_size):
        super().__init__()
        # Attribute names and creation order are unchanged. They form the
        # state_dict keys of saved checkpoints and fix the parameter-init order.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        _, seq_len, width = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        # NOTE(review): scores are scaled by the input width C (= n_embed), not by
        # head_size as in standard scaled dot-product attention. Changing this
        # would alter the behaviour of checkpoints trained with this code.
        scores = queries @ keys.transpose(-2, -1) * width ** -0.5
        future_positions = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future_positions, float('-inf'))
        attention = self.dropout(F.softmax(scores, dim=-1))
        return attention @ values
class MultiHeadAttention(nn.Module):
    """Run `num_heads` causal heads in parallel, concatenate them, and project back to n_embed."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        # Names/order kept: they are the checkpoint's state_dict keys.
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout= nn.Dropout(dropout)

    def forward(self,x):
        head_outputs = [head(x) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.dropout(self.proj(concatenated))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n_embed -> 4*n_embed), ReLU, Linear back to n_embed, then dropout."""

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        # Layer order defines the checkpoint keys (net.0.*, net.2.*); keep it.
        self.net = nn.Sequential(
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attention(ln1(x)), then + feedforward(ln2(.))."""

    def __init__(self,n_embed, num_heads):
        super().__init__()
        head_size = n_embed // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)  # self-attention
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)  # LayerNorm also has trainable parameters

    def forward(self, x):
        # Residual (skip) connections around both sub-layers.
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))


class BigramLanguageModel(nn.Module):
    """Decoder-only transformer language model over the BPE vocabulary.

    Despite its name, this is a stack of `n_layers` attention blocks, not a
    bigram model. The name is kept for compatibility.
    """

    def __init__(self):
        super().__init__()
        # Attribute names/order are the checkpoint's state_dict keys; keep them.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, num_heads) for _ in range(n_layers)],
            nn.LayerNorm(n_embed),
        )
        # NOTE(review): an extra feed-forward layer (no residual) applied after
        # the final LayerNorm. Unusual, but it is part of the saved checkpoints.
        self.ffwd = FeedForward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss).

        Without targets, logits are (B, T, vocab_size) and loss is None. With
        targets, logits are flattened to (B*T, vocab_size) and loss is the
        cross-entropy against the flattened targets.
        """
        B, T = idx.shape
        tok_emd = self.token_embedding_table(idx)
        # Build positions on the input's device instead of relying on the
        # global `device`.
        pos_emd = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emd + pos_emd
        x = self.blocks(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    def _sample_next(self, idx):
        """Sample one next-token id per row, conditioned on the last block_size tokens."""
        logits, _ = self(idx[:, -block_size:])
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to `idx` (shape (B, T)) and return it."""
        for _ in range(max_new_tokens):
            idx = torch.cat((idx, self._sample_next(idx)), dim=1)
        return idx

    @torch.no_grad()
    def generate_one_poem(self, start_token=0, end_token=1, max_new_tokens=None):
        """Sample a single poem, starting from `start_token`, until `end_token` is drawn.

        The defaults 0 and 1 are the ids of '<' and '>', because punctuation and
        tags receive the lowest ids in sorted order. `max_new_tokens` optionally
        caps the length; the default None keeps the original unbounded loop.

        Returns:
            A (1, L) tensor of token ids, including the start token.
        """
        model_device = self.lm_head.weight.device
        idx = torch.full((1, 1), start_token, dtype=torch.long, device=model_device)
        generated = 0
        while max_new_tokens is None or generated < max_new_tokens:
            idx_next = self._sample_next(idx)
            idx = torch.cat((idx, idx_next), dim=1)
            generated += 1
            if idx_next.item() == end_token:
                break
        return idx



Using device: mps


In [10]:
def count_parameters(model):
    """Return the total number of elements in `model`'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

m = BigramLanguageModel().to(device)

# Earlier experiments (trainable parameter counts):
#   layer10_context80_nebd96_nhead8 (BPE): 2,741,600
#   layer14_context80_nebd96_nhead8 (BPE): 3,187,808
#   layer10_context500_nebd180_nhead6:     7,144,460
#   layer4_context500_nebd252_nhead6:      7,734,068
#   layer5_context500_nebd300_nhead6:     11,095,100
#   layer4_context700_nebd384_nhead6:     14,696,384
# The checkpoint name is derived from the hyperparameters, so changing n_layers /
# block_size / n_embed / num_heads selects a matching file instead of loading
# (or later overwriting) a checkpoint of another configuration. With the current
# settings this is 'nano_tang_poem_BPE_layer6_context500_nebd216_nhead6.pt'
# (7,318,952 trainable parameters).
model_path = f'nano_tang_poem_BPE_layer{n_layers}_context{block_size}_nebd{n_embed}_nhead{num_heads}.pt'
if os.path.exists(model_path):
    # Resume from the saved weights when they exist.
    m.load_state_dict(torch.load(model_path, map_location=device))

num_params = count_parameters(m)
print(f"The model has {num_params} trainable parameters.")

The model has 7318952 trainable parameters.


In [14]:
# Baseline: cross-entropy of a uniform guess over the vocabulary is ln(vocab_size).
# Both loss histories start from it, so the first evaluation always counts as an
# improvement.
initial_loss = torch.tensor(vocab_size).log().item()
train_loss_list, val_loss_list = [initial_loss], [initial_loss]

In [35]:
def run_training_phase(model, max_iters, eval_interval, learning_rate, stop_val_loss=None):
    """Train `model` with a fresh AdamW optimizer for up to `max_iters` steps.

    Every `eval_interval` steps, starting at step 0, this does the following:
      * estimates the train/val losses once;
      * appends them to the loss log files and to `train_loss_list` /
        `val_loss_list`;
      * saves the weights to `model_path` whenever the validation loss beats
        every value already in `val_loss_list`.

    Args:
        model: the model to train (switched to train mode).
        max_iters: maximum number of optimizer steps.
        eval_interval: steps between loss evaluations.
        learning_rate: AdamW learning rate.
        stop_val_loss: if given, stop after the step that follows an evaluation
            whose validation loss is below this value.

    Returns:
        The AdamW optimizer used in this phase.
    """
    run_tag = f'layer{n_layers}_context{block_size}_nebd{n_embed}_nhead{num_heads}'
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    optimizer.zero_grad(set_to_none=True)
    val_loss = None

    for steps in range(max_iters):
        xb, yb = get_batch('train')

        if steps % eval_interval == 0:
            # One estimate_loss call yields both splits. The original called it
            # once per split, which evaluated both splits twice.
            losses = estimate_loss(model)
            train_loss, val_loss = losses['train'], losses['val']
            with open(f'train_loss_{run_tag}.txt', 'a') as file:
                file.write(f"{train_loss}\n")
            with open(f'val_loss_{run_tag}.txt', 'a') as file:
                file.write(f"{val_loss}\n")

            if val_loss < min(val_loss_list):
                torch.save(model.state_dict(), model_path)

            train_loss_list.append(train_loss.item())
            val_loss_list.append(val_loss.item())
            print(steps, f'train loss: {train_loss} | validation loss: {val_loss}')

        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        if stop_val_loss is not None and val_loss is not None and val_loss < stop_val_loss:
            break
    return optimizer


# Phase 1: lr 1e-3, evaluated every 300 steps, stop once validation loss < 4.6.
optimizer = run_training_phase(m, max_iters=6001, eval_interval=300, learning_rate=1e-3, stop_val_loss=4.6)
# Phase 2: continue at lr 1e-4, evaluated every 50 steps.
optimizer = run_training_phase(m, max_iters=6001, eval_interval=50, learning_rate=1e-4)


In [11]:
m.eval()
# Sampling needs no gradients. no_grad skips building an autograd graph at every
# generated token.
with torch.no_grad():
    for _ in range(20):
        print(decode(m.generate_one_poem()[0].tolist()))

<昭君怨|六宮初入玉關雲，欲作團圓逐洛神。隴水未分迷遠水，關城初落燕巢春。金魚墮處曾留恨，秋草遮終不惜春。莫向南宮明月夜，冷雲依舊碧桃園。>
<禦制三月三日來|六稼寒還散，重陽別未齊。樽罍新滿醉，物色晚鮮飄。楊柳堪惆悵，長多舊雪題。>
<玉蕊萄|北渚洗瓊樹，傍池搖玉川。君采采樵香，鶴巢小洞蓮。濯纓不背面，自此自憐天。藕葉侵潭暖，樹枝光悄然。期君杜朝客，知我獨匆然。>
<陪劉五貺新詩十二首：強健子謠|群舷言，九牛而冠。賢自逶氏，寫過殷王？其動寥廓，孰雲足。雖非中央，其律有德。其隋無疆，君子安足。天下無德，我庶大志。持湯太三，畢張朱鷺。王公五臣，九華不供。舞環振絕，歌上下。>
<送荊少府赴任|又將杯酒薄，複似朔方急。晝務值春深，遲時喜遙夜。>
<舟中答韋祭酒一詠：懷古得其具詩：公|明前蔽千里，室中獨高樓。仿佛不知處，雲中疑不收。折碑為碎綠，剝翦照狼頭。為長河水底，依舊有終秋。不能披羽節，豈意清泠流。>
<賦兄弟|雙旌汾玉管，一僕奏宸聊。波上桃源綠，煙中鳥道黃。井沾昏楚斷，閣接麗譙涼。想憶歡娛日，三年過故鄉。>
<嘲唐昌宰宅，每篇|項衣三畝地，工得一人交。庾監標天下，狂歌跨世人。馬分關道直，鷹出樹陰緣。話過科鬥地，詩以太虛天。觀宇鶉初合，驚風鶴自圓。田居期社稷，公事舍秋田。>
<送蕭判官|笛歸歸湘水，東楚江南松。風流楚妃怨，千里千里隔。落葉複江陰，蒼然斑竹林。翻令楚遊客，八月長相思。>
<望雷州一望|沾景似攸遊，臨高無定幽。山殊怏雲路，雲晦高槐丘。貝闕亭氣深，歲陰林景幽。豳歌若不豫，千里遊棹舟。>
<觀林寬和元相公領雪|清曉禁暄時，飛風助降木。寥寥珍重韻，渺渺勁如絲。有美尊前人，無須柳即枝。因知有魚賞，猶未山中期。>
<發且懶踵淝下|江村旋風不可到，年來五月謫和戎。行客已多逢豔色，故鄉俄固是枯翁。瑤台共鳥連環影，馬渡沾衣去拂虹。應是往還誰得見，月明月下更南東。>
<同諸公山池雪|東風落芳乘，搖豔吐繁英。歷歷凝層碧，參差排太清。亭園遽映川，翹翳初凝晴。幽嵐尚可靜，露松乍應生。尚畏惟願盡，瀾幽由所精。>
<賦得重載石頭|鳥寄人情代物長，劍纓對舞洛陽宮。胭脂競刺香消盡，愁雨還知惜歲寒。>
<神仙|穆王長送武陵遊，魏武遙驂鸞鳳樓。玉輦先開行子道，青門共許醉忠籌。宴餘新詠憐前事，風裏清哀下舊愁。池照沙平胡雁舞，女班沉沉小苑遊。閑稱二三林下約，不入九衢遊俠遊。隱隱留三山月月，青山長照