In [7]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import unicodedata
import pickle

# Load the id -> string decoding table used by decode() below.
# SECURITY NOTE(review): pickle.load can execute arbitrary code while unpickling,
# so only load an 'itos.pkl' that you produced yourself or otherwise trust.
with open('itos.pkl', 'rb') as file:
    itos = pickle.load(file)


def decode(ids):
    """Turn a sequence of token ids into text.

    Every id is looked up in the module-level ``itos`` table and the
    resulting strings are concatenated in the order given.

    Args:
        ids: iterable of keys/indices valid for ``itos``.

    Returns:
        The decoded string ("" for an empty ``ids``).
    """
    pieces = [itos[token_id] for token_id in ids]
    return "".join(pieces)




def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    Uses the module-level ``train_data`` / ``val_data`` sequences and the
    module-level ``block_size``, ``batch_size`` and ``device`` settings.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)`` moved to ``device``. Each holds ``batch_size`` windows of
        ``block_size`` consecutive items; ``y`` is ``x`` shifted one position
        to the right in the underlying data.
    """
    source = val_data if split != 'train' else train_data
    # Random window start offsets, chosen so the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)
    
@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss of ``model`` on the train and validation splits.

    For each split, ``eval_iters`` random batches are drawn with
    ``get_batch`` and the losses returned by ``model(X, Y)`` are averaged.
    The model is switched to eval mode for the measurement; afterwards its
    previous train/eval mode is restored, also when an exception is raised,
    so calling this on a model that is in eval mode no longer silently
    switches it into train mode.

    Args:
        model: module whose call ``model(X, Y)`` returns ``(logits, loss)``.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each a 0-dim tensor holding
        the mean loss for that split.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                # get_batch already returns tensors placed on `device`.
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out

@torch.no_grad()
def estimate_val_loss(model):
    """Estimate the mean validation loss of ``model``.

    Draws ``eval_iters`` random batches with ``get_batch('val')`` and
    averages the losses returned by ``model(X, Y)``. The model is put in
    eval mode for the measurement; afterwards its previous train/eval mode
    is restored, also when an exception is raised, so a model that was in
    eval mode is no longer silently switched into train mode.

    Args:
        model: module whose call ``model(X, Y)`` returns ``(logits, loss)``.

    Returns:
        0-dim tensor holding the mean validation loss.
    """
    was_training = model.training
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # get_batch already returns tensors placed on `device`.
            X, Y = get_batch('val')
            _, loss = model(X, Y)
            losses[k] = loss.item()
        val_loss = losses.mean()
    finally:
        model.train(was_training)
    return val_loss
            

        
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Reads the module-level ``n_embed``, ``block_size`` and ``dropout`` when
    constructed. The lower-triangular mask is registered as the buffer
    ``tril`` so that it follows the module when it is moved between devices.

    NOTE(review): attention scores are scaled by ``x.shape[-1] ** -0.5``
    (the input channel count) rather than by ``head_size ** -0.5``. This is
    kept exactly as-is, because changing the scale would change the output
    of any checkpoint that was trained with it.
    """

    def __init__(self, head_size):
        super().__init__()
        # Projection layers are created in the order key, query, value.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal_mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (B, T, C); returns (B, T, head_size)."""
        _, seq_len, channels = x.shape
        keys = self.key(x)
        queries = self.query(x)
        values = self.value(x)

        scores = queries @ keys.transpose(-2, -1) * channels ** -0.5
        # Positions above the diagonal are "future" tokens: block them out.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected.

    Reads the module-level ``n_embed`` and ``dropout`` when constructed.

    Args:
        num_heads: number of ``Head`` modules to create.
        head_size: output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # equals n_embed only when n_embed is divisible by num_heads, so size
        # the projection from the real input width instead of assuming it.
        # (Same shape as before in the divisible case, so checkpoints load.)
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected, dropout-regularised concatenation of all heads."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as the input. The dropout
    probability comes from the module-level ``dropout``.

    Args:
        n_embed: width of the input and of the output.
    """

    def __init__(self, n_embed):
        super().__init__()
        hidden_width = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps its size."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Transformer block: pre-norm self-attention, then a pre-norm MLP.

    Both sub-layers are wrapped in a residual (skip) connection.

    Args:
        n_embed: channel width of the block.
        num_heads: number of attention heads; each head gets
            ``n_embed // num_heads`` channels.
    """

    def __init__(self, n_embed, num_heads):
        super().__init__()
        self.sa = MultiHeadAttention(num_heads, n_embed // num_heads)  # self-attention
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed


class BigramLanguageModel(nn.Module):
    """Transformer language model over token ids.

    The class name is kept for compatibility; the network is a stack of
    ``Block`` modules followed by a LayerNorm, one ``FeedForward`` layer and
    a linear head over the vocabulary.

    Reads the module-level ``vocab_size``, ``n_embed``, ``block_size``,
    ``num_heads`` and ``n_layers`` when constructed. Sub-module names are
    part of the checkpoint format (state_dict keys) and must not change.
    """

    # Token ids used by generate_one_poem: generation starts from id 0 and
    # stops once id 1 is sampled.
    # NOTE(review): presumably itos[0] == '<' and itos[1] == '>' (the poem
    # delimiters seen in the printed samples) -- confirm against the vocabulary.
    START_TOKEN_ID = 0
    END_TOKEN_ID = 1

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(n_embed, num_heads) for _ in range(n_layers)],
            nn.LayerNorm(n_embed),
        )
        self.ffwd = FeedForward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of target token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape (B, T, C)
            and loss is None. With targets, logits is flattened to (B*T, C)
            and loss is the cross-entropy against the flattened targets.
        """
        B, T = idx.shape
        tok_emd = self.token_embedding_table(idx)
        # Build the positions on the same device as the input rather than on
        # a module-level `device`, so the model works wherever it was moved.
        pos_emd = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emd + pos_emd
        x = self.blocks(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)

        # `is None` instead of `== None`: never compare a tensor to None with ==.
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def _sample_next(self, idx):
        """Sample one next-token id per row from the model's predictions.

        Only the last ``block_size`` tokens of ``idx`` are fed to the model,
        because the position embedding table has no entries beyond that.
        Returns a (B, 1) tensor of sampled ids.
        """
        idx_cond = idx[:, -block_size:]
        logits, _ = self(idx_cond)
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx``.

        Runs without gradient tracking. Returns the extended
        (B, T + max_new_tokens) tensor.
        """
        for _ in range(max_new_tokens):
            idx = torch.cat((idx, self._sample_next(idx)), dim=1)
        return idx

    @torch.no_grad()
    def generate_one_poem(self, max_new_tokens=None):
        """Sample one sequence from the start token until the end token.

        Args:
            max_new_tokens: optional upper bound on the number of sampled
                tokens. The default ``None`` keeps sampling until the end
                token appears, which may take arbitrarily long for a model
                that rarely emits it.

        Returns:
            (1, L) tensor of token ids, beginning with ``START_TOKEN_ID``.
        """
        model_device = self.token_embedding_table.weight.device
        idx = torch.full((1, 1), self.START_TOKEN_ID, dtype=torch.long, device=model_device)
        produced = 0
        while max_new_tokens is None or produced < max_new_tokens:
            idx_next = self._sample_next(idx)
            idx = torch.cat((idx, idx_next), dim=1)
            produced += 1
            if idx_next.item() == self.END_TOKEN_ID:
                break
        return idx

In [11]:
# Model hyperparameters. They are read as module-level names by the model
# classes and are also used to build the checkpoint file name.
block_size = 500        # context length (size of the position embedding table)
vocab_size = len(itos)  # one entry per decodable token
n_embed = 216           # embedding width
num_heads = 6           # attention heads per block
n_layers = 5            # number of transformer blocks

# The model classes also read a module-level `dropout` when they are built,
# but no visible cell of this notebook defines it, so a fresh kernel would
# fail with NameError. Keep a value set by an earlier cell if there is one;
# otherwise fall back to 0.0 (dropout is inactive in eval mode anyway).
try:
    dropout
except NameError:
    dropout = 0.0

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Checkpoint names seen so far, with the trainable-parameter counts that were
# recorded next to them:
#   nano_tang_poem_layer8_context80_nebd64_nhead4.pt      1,423,780
#   nano_tang_poem_layer10_context80_nebd64_nhead4.pt     1,523,364
#   nano_tang_poem_layer10_context80_nebd96_nhead8.pt     2,674,436
#   nano_tang_poem_layer6_context500_nebd216_nhead6.pt    7,131,030
#   nano_tang_poem_layer6_context500_nebd252_nhead6.pt    9,044,034
#   nano_tang_poem_layer7_context500_nebd252_nhead6.pt    9,808,602
#   nano_tang_poem_layer7_context500_nebd300_nhead6.pt   13,000,266
#   nano_tang_poem_layer12_context500_nebd252_nhead6.pt  13,631,442
#   nano_tang_poem_layer4_context500_nebd216_nhead6.pt    6,006,966
#   nano_tang_poem_layer5_context500_nebd216_nhead6.pt    6,568,998
# For inference the file name is derived from the hyperparameters, so the
# checkpoint that is looked up corresponds to the architecture built here.
model_path = f'nano_tang_poem_layer{n_layers}_context{block_size}_nebd{n_embed}_nhead{num_heads}.pt'

m = BigramLanguageModel().to(device)

if not os.path.exists(model_path):
    # Execution continues with the freshly initialised (untrained) weights.
    print("Model file does not exist")
else:
    # weights_only=True restricts unpickling to tensors and plain containers.
    m.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    print("Load existing model complete.")

num_params = count_parameters(m)
print(f"The model has {num_params} trainable parameters.")
print(f"Embeding dimension = {n_embed},\nContext length = {block_size},\nnumber of heads per layer = {num_heads},\nnumber of layers = {n_layers}")

Using device: mps
Load existing model complete.
The model has 6568998 trainable parameters.
Embeding dimension = 216,
Context length = 500,
number of heads per layer = 6,
number of layers = 5


  return self.fget.__get__(instance, owner)()


In [12]:
seed = 12312221
num_poems = 10  # how many poems to sample and print

# Seed the sampler so the same poems come out on every run of this cell.
torch.manual_seed(seed)
print(f'Seed = {seed}')
print(f"Using model:  {model_path}")

m.eval()  # switch dropout off for sampling
# Sampling never needs gradients, so run it without autograd bookkeeping.
with torch.no_grad():
    for _ in range(num_poems):
        poem_ids = m.generate_one_poem()[0].tolist()
        # Author's note: "loss at 4.53" (presumably this checkpoint's loss -- confirm).
        print(decode(poem_ids))

Seed = 12312221
Using model:  nano_tang_poem_layer5_context500_nebd216_nhead6.pt
<送靈徹明王府|耿耿霞山裏，蔥林紫氣微。樓臺淩夕氣，樓作度秋暉。餘雨生秋晚，殘蟬向夕稀。離情徒喜遇，雅思滿前飛。>
<贈韋指封南歸|片帆臨太白，潮水聚新泉。雪散江雲遠，雲飛楚草連。紫霞雲夢裏，滄海客行前。白日機心靜，青山磬帶寒。仍為碧海客，俱為春州田。>
<納思|曉引迸苔機，雲到架蓬重。桃花清淺景，嶰澗綠縈融。乍拂文雲水，低添舞鶴峰。夜涼前後望，宵吹掃還空。經我鶴指鳳，驅馳雉北風。歌謠燠塵俗，飄落響寒風。白露臨寒景，紅霞帶暖空。蕭疏嘯蕭瑟，淅瀝怨生紅。擁砌聯翩動，兼軒罷卷空。早涼身自喜，落日志難窮。歡賞追歡悅，良辰暇志通。豫遊如未得，還會在歌蓬。>
<陪旻公白石屏|伏寺請昨日，八天生甑亡。丹梯我在高，世業常在目。已憐江海舟，千年夏江使。想我二十年，從此便堪異。龍才浮雲車，歲暮涼風利。時施轉蓬龍，玉峰生鬼魅。糞土蹋古木，甘心資寒暑。豐萌何其理，疏俗多精縮。空遷嘉辛裏，鑿石忻無歲。下山順春流，餘雪覆深竹。松蘿起枯木，偶與澗穀廣。靈液泛修修，直須上神怪。不知煩世趣，日月成神格。琴沉日虛靜，齋潔如氛昧。仿佛放未至，焚香生所憶。桃源若有人，蹇步遂無事。方隨化城會，此外唯歸趣。>
<關中作|早得稽山不發吹，此生唯有舊名時。晚來楊柳連中老，積雪朦朧在小兒。今日幾宵蘭若故，殷勤更賦訪經時。>
<狄昭中相公樓歌|滕公閣夏西壇在，插竹燈陰一望清。江路春還小杏禦，村橋斜日碧芙蓉。南庭不作清朝賞，金殿池台便引行。別後吟聲在書牖，酒醒相思倍四更。>
<四明宮之詩：草草磯|梨花片葉滿危汀，莫使朝來不倚春。王母莫留牽意處，懶教明日到深春。>
<奉和崔相公|帝山推磬警降祥，豐沛尊陪瑞最祥。自昔齊神堯舜合，長長漢轉禦雞場。>
<送刑賢|詩中萬首諭，覺酒各流杯。竹影應難扣，松聲只自開。誰知戎客意，別路石樓臺。>
<中馬自河|亞尋花下草頭看，百六州深螮蝀寬。草上橋頭蘇子盛，渭中門裏幕中寒。松風割馬朱錢點，松雪粘花半眼看。>
