In [44]:
import os
import json
import pickle
from contextlib import nullcontext
from easydict import EasyDict

import torch
import torch.nn.functional as F

import sys
sys.path.append('../')
from model import GPT

In [45]:
# Sampling settings and model hyper-parameters are stored as JSON files.
# Each parsed dict is wrapped in an EasyDict so that its keys can be read
# as attributes (e.g. ``sample_config.seed``).
with open('../config/sample.json', 'r') as sample_file:
    sample_config = EasyDict(json.load(sample_file))

with open('../config/model.json', 'r') as model_file:
    model_config = EasyDict(json.load(model_file))

In [46]:
# Display the model hyper-parameters loaded from ../config/model.json.
model_config

{'batch_size': 12,
 'block_size': 128,
 'vocab_size': 12966,
 'n_embd': 384,
 'n_head': 12,
 'n_layer': 12,
 'dropout': 0.0}

In [47]:
# Display the sampling settings loaded from ../config/sample.json.
sample_config

{'init_from': 'resume',
 'start': '\n',
 'num_samples': 10,
 'max_new_tokens': 500,
 'temperature': 0.95,
 'top_k': 200,
 'seed': 1337}

In [48]:
# Seed the CPU and CUDA random number generators from the sampling config.
torch.manual_seed(sample_config.seed)
torch.cuda.manual_seed(sample_config.seed)

# Allow TF32 in CUDA matmuls and in cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Autocast precision name: 'bfloat16' when CUDA supports it, else 'float16'
# ('float32' is also a valid key of the lookup table below).
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'

# torch.autocast expects the bare device family, not e.g. 'cuda:0'.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Translate the precision name into the corresponding torch dtype.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# On CPU run without autocast (no-op context); otherwise use mixed precision.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [49]:
# NOTE: sample_config.init_from is not consulted here; the checkpoint below is
# always loaded, whatever its value.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load('../params/chipogen_model26.0M.pth', map_location=device)
# The model weights are stored under the 'model' key of the checkpoint.
state_dict = checkpoint['model']

# Build the model from the JSON hyper-parameters and restore its weights; the
# result of load_state_dict is displayed as the cell output.
gpt_model = GPT(model_config)
gpt_model.load_state_dict(state_dict)

number of parameters: 26.22M


<All keys matched successfully>

In [50]:
# Move the model to the selected device; the returned module is displayed as
# the cell output.
gpt_model.to(device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(12966, 384)
    (wpe): Embedding(128, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (attn): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): SelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=12966, bias=False)
)

In [51]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a vocabulary file that comes from a trusted source.
with open('../data/tokens.pkl', 'rb') as vocab_file:
    tokens = pickle.load(vocab_file)

# Lookup tables between tokens and integer ids; a token's id is its position
# in ``tokens``.
itos = dict(enumerate(tokens))
stoi = {token: token_id for token_id, token in itos.items()}


def encode(s):
    """Return the list of ids for the items of ``s``, looked up in ``stoi``.

    Raises KeyError for any item that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string formed by joining the tokens for the ids in ``l``."""
    return ''.join([itos[i] for i in l])

In [52]:
# Encode the prompt and add a leading batch axis so the tensor is laid out as
# (1, T). generate() slices and concatenates along dim 1 as the time axis and
# only row 0 of its result is decoded afterwards, so the whole prompt must
# live in a single row. (A (T, 1) layout would instead treat every prompt
# token as a separate one-token sequence whenever the prompt has more than
# one token.)
start_ids = encode(sample_config.start)
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
x

tensor([[0]], device='cuda:0')

In [53]:
# Put the model in evaluation mode before sampling; the returned module is
# displayed as the cell output.
gpt_model.eval()

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(12966, 384)
    (wpe): Embedding(128, 384)
    (drop): Dropout(p=0.0, inplace=False)
    (attn): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm()
        (attn): SelfAttention(
          (c_attn): Linear(in_features=384, out_features=1152, bias=False)
          (c_proj): Linear(in_features=384, out_features=384, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=384, out_features=1536, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1536, out_features=384, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=384, out_features=12966, bias=False)
)

In [54]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature, top_k, block_size=None):
    """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

    Args:
        model: Callable returning ``(logits, _)`` for a batch of token ids.
            The logits are indexed as ``[:, -1, :]``, i.e. they are expected
            to be laid out as (batch, time, vocab).
        idx: Integer tensor of token ids, indexed as (batch, time).
        max_new_tokens: Number of sampling steps to run.
        temperature: Divisor applied to the last-step logits before the
            softmax; must be non-zero.
        top_k: If not ``None``, only the ``top_k`` largest logits of each row
            (capped at the vocabulary size) can be sampled.
        block_size: Maximum number of trailing tokens fed to the model at each
            step. Defaults to the notebook-level ``model_config.block_size``,
            which keeps existing five-argument calls working unchanged.

    Returns:
        ``idx`` with ``max_new_tokens`` sampled columns appended along dim 1.
    """
    if block_size is None:
        # Backward-compatible default: context length of the loaded config.
        block_size = model_config.block_size

    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx[:, -block_size:]
        # forward the model to get the logits
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            # everything below the k-th largest logit of its row becomes unsampleable
            logits[logits < v[:, [-1]]] = -float('inf')

        # turn the logits into probabilities and draw one token per row
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        # append the sampled token to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

In [55]:
# Sample a continuation of the prompt tensor ``x`` with gradient tracking
# disabled and inside the precision context ``ctx`` (autocast on CUDA, a no-op
# on CPU).
# NOTE(review): only one generate() call is made here; sample_config.num_samples
# is not used by this cell.
with torch.no_grad(), ctx:
    y = generate(
        gpt_model,
        x,
        sample_config.max_new_tokens,
        sample_config.temperature,
        sample_config.top_k,
    )

In [56]:
# Take the first row of the generated batch as a plain Python list of token ids.
ids = y[0].tolist()

In [57]:
# Map the token ids back to text and print the generated sample.
print(decode(ids))



爱此岩霏雪，深惭吏隐诗。
相逢非论地，一别近谁期。
月朗江村暮，风清道路迟。
明朝问流水，回棹可来兹。

朝退时应到，吟窗独对僧。
客愁真滚滚，世事莫懵腾。
过眼朝朝急，论心事事增。
浮山只如此，何处不堪乘。

久废风流兴，虚斋复宴无。
幽人诗句好，小室酒垆孤。
书帙三杯湿，香瓶两柄枯。
此时神胜在，不必问还无。

南来忽作数峰云，更问清泉与浅濆。
白月长松多在下，青山浊酒不嫌醺。
楼头晓市人烟散，竹里西峰曙色分。
何处玉泉堪借问，几回斜日忆劳君。

仙人住在海云乡，药石风流一钓航。
太古易传天上箓，清秋同过石头庄。
山中猿鹤频来往，江上烟霞独渺茫。
知有山川称逸士，应题题字寄柴桑。

十年高卧斗牛宫，相对山川兴趣同。
欲返风尘成未得，相期水月是成功。
中流极目知多少，曲径通村有几重。
自笑老来乘醉去，肯令尘世混游龙。

使君频下使关中，宣武归来语便雄。
花外几多高士宅，尘中一舸太湖风。
青山已许凭栏绿，绿树应愁夹洞红。
我有清琴和不断，愿将明月破西东。

野叟家风不可寻，故应丘壑尚知音。
年来未尽黄金价，老去何由赤帜心。
时向此中思得句，客逢明月作知音。
道人坐对山风起，更把离觞更
