In [2]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import unicodedata

#import time

	
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
import unicodedata
import pickle

# Load the id -> character decoding table (`itos`) used by `decode`.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load an 'itos.pkl' that comes from a trusted source.
with open('itos.pkl', mode='rb') as itos_file:
    itos = pickle.load(itos_file)


#define function that decodes numbers to texts
def decode(ids):
    """Convert a sequence of token ids back into text.

    Each id is looked up in the module-level ``itos`` table and the resulting
    strings are concatenated in order.

    Args:
        ids: iterable of token ids (e.g. the result of ``tensor.tolist()``).

    Returns:
        The decoded string ('' for an empty sequence).
    """
    pieces = [itos[token_id] for token_id in ids]
    return "".join(pieces)

    

def get_batch(split):
    """Sample a random batch of (input, next-token target) windows.

    Reads the module-level ``train_data``, ``val_data``, ``block_size``,
    ``batch_size`` and ``device``.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)`` moved to ``device``. Assuming the data are 1-D token tensors,
        both are (batch_size, block_size) and ``y`` is ``x`` shifted one
        position later in the source data.
    """
    source = val_data if split != 'train' else train_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)
    
@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss of ``model`` on the train and val splits.

    For each split, averages the loss over ``eval_iters`` random batches from
    ``get_batch`` with the model in eval mode and gradients disabled.

    Args:
        model: module whose call ``model(X, Y)`` returns ``(logits, loss)``.

    Returns:
        dict mapping 'train' and 'val' to 0-dim tensors holding the mean loss.
        The model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                # get_batch already places X and Y on `device`.
                X, Y = get_batch(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        return out
    finally:
        # Restore the caller's mode instead of forcing train mode, so calling
        # this on an eval-mode model does not silently re-enable dropout.
        model.train(was_training)

@torch.no_grad()
def estimate_val_loss(model):
    """Estimate the mean validation loss of ``model``.

    Averages the loss over ``eval_iters`` random 'val' batches from
    ``get_batch`` with the model in eval mode and gradients disabled.

    Args:
        model: module whose call ``model(X, Y)`` returns ``(logits, loss)``.

    Returns:
        0-dim tensor holding the mean validation loss. The model's previous
        train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # get_batch already places X and Y on `device`.
            X, Y = get_batch('val')
            _, loss = model(X, Y)
            losses[k] = loss.item()
        return losses.mean()
    finally:
        # Restore the caller's mode rather than unconditionally switching to train.
        model.train(was_training)
            

        
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to key/query/value spaces of width ``head_size`` and
    lets each position attend only to itself and earlier positions. The causal
    mask is a registered buffer, so it follows the module across devices.

    Reads the module-level ``n_embed``, ``block_size`` and ``dropout``.
    """

    def __init__(self, head_size):
        super().__init__()
        # Construction order is kept (key, query, value) so seeded
        # initialization is reproducible.
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 iff j <= i.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C); returns (B, T, head_size)."""
        _, seq_len, channels = x.shape
        keys, queries, values = self.key(x), self.query(x), self.value(x)
        # NOTE(review): scaled by the input width C (n_embed), not head_size as
        # in standard scaled dot-product attention. Left unchanged because saved
        # checkpoints were presumably trained with this scaling.
        scores = (queries @ keys.transpose(-2, -1)) * channels ** -0.5
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ values
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, concatenated and projected.

    Reads the module-level ``n_embed`` (output width) and ``dropout``.

    Args:
        num_heads: number of ``Head`` modules.
        head_size: output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated heads are num_heads * head_size wide. Sizing the
        # projection from that, rather than assuming it equals n_embed, keeps
        # the module valid when n_embed is not divisible by num_heads. When
        # num_heads * head_size == n_embed (e.g. 6 * 36 == 216) this is
        # identical to nn.Linear(n_embed, n_embed).
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected concatenation of every head's output for ``x``."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)

class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x, ReLU, project back, then dropout.

    Reads the module-level ``dropout``.

    Args:
        n_embed: input and output feature width.
    """

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        # A single Sequential named `net` keeps state_dict keys as
        # net.0.* / net.2.*, matching the original layout.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        return self.net(x)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention, then MLP, each residual.

    Args:
        n_embed: model width.
        num_heads: number of attention heads; each head gets n_embed // num_heads dims.
    """

    def __init__(self, n_embed, num_heads):
        super().__init__()
        self.sa = MultiHeadAttention(num_heads, n_embed // num_heads)  # self-attention
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))


class BigramLanguageModel(nn.Module):
    """Character-level causal transformer language model.

    Token embeddings go straight into a stack of ``Block`` modules followed by
    a LayerNorm, an extra ``FeedForward`` and a linear head over the vocabulary.
    This variant deliberately has no position embedding.

    Reads the module-level ``vocab_size``, ``n_embed``, ``num_heads``,
    ``n_layers``, ``block_size`` and ``device``.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        # Position embedding intentionally removed in this variant.
        self.blocks = nn.Sequential(
            *[Block(n_embed, num_heads) for _ in range(n_layers)],
            nn.LayerNorm(n_embed),
        )
        # Extra feed-forward after the final LayerNorm; part of the saved
        # architecture, so it must remain for existing checkpoints to load.
        self.ffwd = FeedForward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        x = self.token_embedding_table(idx)
        x = self.blocks(x)
        x = self.ffwd(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    def _next_token(self, idx):
        """Sample one next-token id per row from the distribution at the last position."""
        logits, _ = self(idx[:, -block_size:])  # crop to the context window
        probs = F.softmax(logits[:, -1, :], dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T).

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range(max_new_tokens):
            idx = torch.cat((idx, self._next_token(idx)), dim=1)
        return idx

    @torch.no_grad()
    def generate_one_poem(self, max_new_tokens=None, start_token=0, end_token=1):
        """Sample a single sequence until ``end_token`` is produced.

        Args:
            max_new_tokens: optional cap on sampled tokens. ``None`` (the
                default) samples until ``end_token``, as before. With a cap, the
                result may end without ``end_token``.
            start_token: id used as the one-token prompt (default 0).
            end_token: id that stops sampling (default 1; presumably the '>'
                poem terminator, confirm against ``itos``).

        Returns:
            (1, L) tensor of token ids on ``device``, including the start token
            and, if reached, the end token.
        """
        idx = torch.full((1, 1), start_token, dtype=torch.long, device=device)
        produced = 0
        while max_new_tokens is None or produced < max_new_tokens:
            idx_next = self._next_token(idx)
            idx = torch.cat((idx, idx_next), dim=1)
            produced += 1
            if idx_next.item() == end_token:
                break
        return idx

def find_poem_lengths(poem, end_char='>'):
    """Return the length of each ``end_char``-terminated segment of ``poem``.

    Each length counts every character of the segment, including the
    terminating ``end_char``. Trailing text after the last terminator (an
    unfinished poem) is not counted.

    Args:
        poem: text (or any iterable of characters) holding one or more poems.
        end_char: poem terminator; '>' matches the ``<title|body>`` layout of
            the generated samples.

    Returns:
        list of int lengths, in order of appearance ([] if no terminator).
    """
    lengths = []
    current = 0
    for ch in poem:
        current += 1
        if ch == end_char:
            lengths.append(current)
            current = 0
    return lengths


        

In [4]:
def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are excluded.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed before building the model so fresh weight initialization is reproducible.
torch.manual_seed(13997)

# Hyperparameters. The model classes read these module-level names when they
# are constructed, so the architecture-defining ones must match the checkpoint
# stored at `model_path` for load_state_dict to succeed.
batch_size = 96
block_size = 500

vocab_size = len(itos)
n_embed = 216
num_heads = 6
dropout = 0.1
n_layers = 8
eval_iters = 100

m = BigramLanguageModel().to(device)
num_params = count_parameters(m)

model_path = 'nano_tang_poem_without_pos_emb_layer8_context500_nebd216_nhead6.pt' #8147094 trainable parameters.

if os.path.exists(model_path):
    # weights_only=True restricts torch.load to tensors and plain containers.
    m.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    print("Load existing model complete.")
else:
    # Nothing is written to disk here; the model keeps its random initialization.
    print(f"No saved weights found at {model_path}; using freshly initialized weights.")

print(f"The model has {num_params} trainable parameters.")
print(f"Embeding dimension = {n_embed},\nContext length = {block_size},\nnumber of heads per layer = {num_heads},\nnumber of layers = {n_layers}")

Using device: mps
Load existing model complete.
The model has 8147094 trainable parameters.
Embeding dimension = 216,
Context length = 500,
number of heads per layer = 6,
number of layers = 8


  return self.fget.__get__(instance, owner)()


In [5]:
seed = 0
torch.manual_seed(seed)
print(f'Seed = {seed}')
print(model_path)
m.eval()
# Pure inference: disable autograd so no graph is built for each forward pass.
with torch.no_grad():
    for _ in range(10):
        # NOTE(review): "loss at 4.54" presumably records the checkpoint's loss — confirm.
        print(decode(m.generate_one_poem()[0].tolist()))

Seed = 0
nano_tang_poem_without_pos_emb_layer8_context500_nebd216_nhead6.pt
<寒行台驛|秦樓聽警冬征，虞氏即令威。銅與雲光上，玉爐煙景浮。素姿當歲駟，天步佇光輝。>
<中秋夕人房，自使|白首荒城陌，風塵頓寂寥。野花迎古村落，日雨傍池塘。近竹風來晝，幽人夢覺遙。松聲猿叫一聲，宿鳥聽翻簫。杳杳滄溟淺，潺潺在檻寬。>
<賀李。殷堯湯|上下渚度秋山，繁花別楚鄉。天河浴行槳，槎枿憑將亡。松殿宜陵雨，關山獨鳥翔。萬家當甲第，雙堠似春香。會合齊平嶽，還將白首陽。>
<除妻重見宴駑駘二吳興|自是蘭山別，三台誰與雲。連天南面月，待婢禦衣雲。苑戍連雞劍，樓蘭斷柏尊。封侯天寵獻，見沐漢儀文。>
<楊柳陌|金履青嵐八兩坡，碧潭紅杏正開。陸機飛蓋寒花退，青榜高花暗霧飛。舊恨洛陽無女妒，暮塵西笑送青衣。芙蓉朵帶綠銀瓶，橫雉飛驚佩繡行。結綺搖歌自舞，秦檀尾乘鸞。香散何時到唐老，玉顏一度徹雙蛾。>
<過遠村人居|商郊急雨家村，江上時聞幾村。自惜不知春早，且將池上曲還。>
<曲歌行|楊葉依輕絮斜，水邊深夜夜吳家。隴草蕭條斜日月，離人牽斷翠娥花。遼東北寒與愁來，莫使秋風裏酒杯。三月江亭夜月，誰家越女正開花。說著春寒食路遙，燈微雨送殘花只。萬里暮吟吳畔葉，竟陵春雨辟寒衣。>
<榮席|美女嬌初動鬢圓，時時調少莫妍嬌。>
<醉輕|酌美金罍催滿安，熟沈玳瑁帳長鏗。更堪縹緲繁須足，莫忘公才有社床空。>
<句|瘴何賈生泊，愁容不自眠。有帆琴下過，吹管動湖邊。>
