In [19]:
import os
import json
import pickle
from contextlib import nullcontext
from easydict import EasyDict

import torch
import torch.nn.functional as F

import sys
sys.path.append('../')
from model import GPT

In [20]:
with open('../config/sample.json', 'r') as f:
    sample_config = json.load(f)
sample_config = EasyDict(sample_config)

with open('../config/model.json', 'r') as f:
    model_config = json.load(f)
model_config = EasyDict(model_config)

In [21]:
model_config

{'batch_size': 12,
 'block_size': 128,
 'vocab_size': 12966,
 'n_embd': 384,
 'n_head': 12,
 'n_layer': 12,
 'dropout': 0.0}

In [22]:
sample_config

{'init_from': 'resume',
 'start': '春',
 'num_samples': 10,
 'max_new_tokens': 500,
 'temperature': 0.95,
 'top_k': 200,
 'seed': 1337}

In [23]:
torch.manual_seed(sample_config.seed)
torch.cuda.manual_seed(sample_config.seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [24]:
# if sample_config.init_from == 'resume':
checkpoint = torch.load('../params/chipogen_model26.0M.pth', map_location=device)
gpt_model = GPT(model_config)
state_dict = checkpoint['model']
gpt_model.load_state_dict(state_dict)

number of parameters: 26.22M


<All keys matched successfully>

In [25]:
gpt_model.to(device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(12966, 384)
    (wpe): Embedding(128, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (attn): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): SelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=12966, bias=False)
)

In [26]:
with open('../data/tokens.pkl', 'rb') as f:
    tokens = pickle.load(f)

stoi = { w:i for i,w in enumerate(tokens)}
itos = { i:w for i,w in enumerate(tokens)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

In [28]:
start_ids = encode(sample_config.start)
x = torch.tensor(start_ids, dtype=torch.long, device=device).view(-1, 1)
x

tensor([[4899]], device='cuda:0')

In [29]:
gpt_model.eval()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(12966, 384)
    (wpe): Embedding(128, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (attn): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): SelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=12966, bias=False)
)

In [30]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature, top_k):
    
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx[:, -model_config.block_size:]
        # forward the model to get the logits
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature # pluck the logits at the final step and scale by desired temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('inf')
        
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    
    return idx

In [31]:
with torch.no_grad():
    with ctx:
        y = generate(gpt_model, x, 
                         sample_config.max_new_tokens, 
                         sample_config.temperature,
                         sample_config.top_k)

In [32]:
ids = y[0].tolist()

In [33]:
print(decode(ids))

春又一。
欲知花底闲，烟云想葱茜。
却恐花外闻，谁复在顷刻。

世间无几道，日月等浮烟。
此自看游戏，人间路百千。
千年吾未晓，一镜我难悬。
我是双珠泪，知心直共缘。

有所思，在绝代。
江山只在眼底。

天河独上星斗寒，银盘露湿黄金盘。
清光不似银河鹊，一寸琼浆掌上看。

晓起西窗冷，云屏一望高。
玉函香冷月还在，无柰冷香生碧桃。

野田禾果绿纷纷，花坞扶疏竹笕分。
野老手中看竹去，小儿蓑笠惯翻文。

庭西月地更清凉，谁识仙家种玉堂。
莫傍花阴寻伴侣，老仙曾住董仙乡。

海风吹作落花香，雪片初飞墨未乾。
天上行人新有约，山中明月似无眠。

云满空山鹤下蓑，野情未觉客愁多。
雪中相对一樽酒，花外高吟两鬓皤。

山中来往尽如梦，何处相逢未得归。
水似白云横野坐，江如飞雪放春飞。

楼阁峰边湖一环，坐来乘醉鬓堪斑。
故人携手观天乐，会有神仙在世间。

松陵古道秋云昏，几度高秋把菊蹲。
风扫云烟閒似我，醉吟泉石冷于门。

溪流曲折转萦回，绿树阴阴鸟雀哀。
地近欲穷平远地，山横不断旧时台。

一溪清浅静无波，水国鱼山共浅莎。
不见清歌明月夜，始知秋色近如何。

绿树阴阴静乍寒，一轮秋色在栏杆。

