In [5]:
import os
from contextlib import nullcontext
import torch
from model import GPTConfig, GPT
from bertviz import head_view
from dataset import Converter, LMDataset

# Seed every torch RNG entry point with one value so repeated runs start
# from the same random state.
seed = 2024
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    seed_rng(seed)

# Backend switches: request deterministic cuDNN algorithms and permit TF32
# in both cuDNN and cuBLAS matmul.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul

#################################################
# User-editable settings: model preset name, checkpoint directory,
# dataset root, and the text file whose attention is visualized.
model_name = 'mygpt'  # key used to look up a preset in GPTConfig
ckpt_path = 'workdirs/quansongci'  # checkpoint directory; 'best.pth' is joined onto it at load time
data_root = 'data/quansongci'  # passed to LMDataset
vis_text_path = 'data/vis/vis_1.txt'  # UTF-8 text file read as the model input
#################################################

device = 'cpu'  # device for the input tensor, the checkpoint map_location and the model

# The 'train' split's stoi/itos tables are what the Converter is built from.
dataset = LMDataset(data_root, 'train')
converter = Converter(dataset.stoi, dataset.itos)

# Read the entire visualization text; it is used both as the model input and
# as the source of the per-position labels.
with open(vis_text_path, 'r', encoding='utf-8') as text_file:
    start = text_file.read()

# One label per character of the text.
# NOTE(review): assumes single_encode yields exactly one id per character so
# the labels line up with the attention positions — confirm against Converter.
start_texts = list(start)
start_ids = converter.single_encode(start)
# Add a leading axis of size 1 in front of the encoded ids.
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
print(f"Input texts: {start}")

Input texts: +++如梦令
昨夜雨疏风骤。浓睡不消残酒。试问卷帘人，却道海棠依旧。知否。知否。应是绿肥红瘦。


In [7]:
# model
# Precision used by autocast when running on a non-CPU device.
dtype = 'float16' # 'float32' or 'bfloat16' or 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# torch.autocast expects a bare device type ('cuda'), not an indexed device
# string ('cuda:0'), so strip any index before choosing the context.
device_type = torch.device(device).type
# On CPU no autocast is applied; otherwise run the forward pass under autocast.
ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)
# init from a model saved in a specific directory
# Keep the directory (ckpt_path) and the file (ckpt_file) under separate names:
# overwriting ckpt_path here would append 'best.pth' again every time this
# cell is re-run.
ckpt_file = os.path.join(ckpt_path, 'best.pth')
print(f"loading model params from {ckpt_file}")
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(ckpt_file, map_location=device)
# Start from the named preset; arguments stored in the checkpoint take precedence.
gptconf = GPTConfig[model_name]
if 'model_args' in checkpoint:
    gptconf = checkpoint['model_args']
model = GPT(**gptconf)
state_dict = checkpoint['state_dict']
model.load_state_dict(state_dict)

# Switch to evaluation mode and place the model on the target device.
model.eval()
model.to(device)

# Single forward pass (no sampling) with gradients disabled; the model's
# second return value is what gets handed to bertviz.
with torch.no_grad(), ctx:
    # Bind the unused first output to a named throwaway rather than `_`,
    # which IPython reserves for the previous cell's result.
    _unused, attn_weights = model(x)

# Render the head view, using start_texts as the token labels.
head_view(attn_weights, start_texts)

loading model params from workdirs/quansongci\best.pth


<IPython.core.display.Javascript object>