In [None]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
import sys

# Make the repository root importable so the local `circe` package resolves.
sys.path.append(r"../../")
from circe.models.LightningClassifier import LightningClassifier

# Model / data configuration files from the training setup.
# NOTE(review): cfg_data is not referenced by the cells visible here.
cfg_model = OmegaConf.load('../training/conf/model/hf-gpt.yaml')
cfg_data = OmegaConf.load('../training/conf/data/text-gpt.yaml')

ckpt_path = "../../../models-hfGPT-shakespeare/lightning_logs/version_3/checkpoints/epoch=31-step=15712.ckpt"
model = LightningClassifier(cfg=cfg_model)
# presumably materializes the underlying network (Lightning's sharded-model hook) so that
# load_state_dict below has parameters to fill — TODO confirm against LightningClassifier.
model.configure_sharded_model()
# map_location="cpu": restore checkpoint tensors on CPU regardless of the device they were
# saved from, so the checkpoint loads on machines without that GPU. load_state_dict copies the
# values into the model's existing parameters, so the model's own device is unaffected.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"])
model.eval();  # trailing ';' suppresses the (very long) module repr in the notebook output

In [None]:
out_dir = 'out'  # leftover setting; not referenced by any sampling code in this notebook
# Prompt text. The vocabulary is character-level (meta.pkl's stoi), so every character must be
# in that vocabulary; multi-character special tokens are NOT recognized, only split into chars.
# The "FILE:prompt.txt" form is only honoured if the sampling code resolves that prefix.
start = "\n"
num_samples = 10  # number of samples to draw
max_new_tokens = 500  # number of tokens generated in each sample
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
# Retain only the top_k most likely tokens, clamping others to 0 probability.
# NOTE(review): a char-level vocab may be smaller than 200; confirm model.generate clamps top_k.
top_k = 200
seed = 1337
# Prefer CUDA, but fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
meta_path = "../../../datasets/shakespeare_char/meta.pkl"

In [None]:
# Seed the CPU generator first, then the current CUDA device's generator, so sampling is
# reproducible. torch.cuda.manual_seed is documented as safe to call when CUDA is unavailable.
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

In [None]:
import pickle
from contextlib import nullcontext

# Load the character-level vocabulary mappings stored with the dataset.
# NOTE(review): pickle.load can execute arbitrary code — only point meta_path at a trusted file.
with open(meta_path, 'rb') as meta_file:
    meta = pickle.load(meta_file)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Encode string ``s`` as a list of token ids, one id per character (via ``stoi``)."""
    return [stoi[c] for c in s]


def decode(ids):
    """Decode an iterable of token ids back into a string (via ``itos``)."""
    return ''.join([itos[i] for i in ids])


# Resolve the prompt: "FILE:<path>" reads it from a UTF-8 text file, anything else is used
# verbatim. A separate name keeps `start` untouched, so re-running this cell is safe.
if start.startswith("FILE:"):
    with open(start[len("FILE:"):], 'r', encoding='utf-8') as prompt_file:
        prompt = prompt_file.read()
else:
    prompt = start

# Fail early with a readable message instead of a bare KeyError from encode(), and reject an
# empty prompt, which would hand the model a zero-length context.
unknown_chars = sorted(set(prompt) - set(stoi))
if unknown_chars:
    raise ValueError(f"prompt contains characters missing from the vocabulary: {unknown_chars!r}")
if not prompt:
    raise ValueError("prompt must contain at least one character")

# Nothing earlier moves the model to the configured `device`; do it here so generation runs
# there (otherwise the CUDA autocast below would only wrap CPU computation and do nothing).
model.to(device)

start_ids = encode(prompt)
# [None, ...] adds a leading batch dimension: shape (1, len(prompt)).
x = torch.tensor(start_ids, dtype=torch.long, device=next(model.parameters()).device)[None, ...]

# run generation: float16 autocast on CUDA, plain fp32 on CPU
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=torch.float16)
with torch.no_grad(), ctx:
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(decode(y[0].tolist()))
        print('---------------')
